[Sync] Update to latest code (#2679)

* [Sync] Update to latest code

* Add new code files

* Add new code files

* update code

* Try to fix build.sh

* Try to fix build.sh

* Update code

* Update requirements.txt

* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
Jiang-Jia-Jun
2025-07-03 15:43:53 +08:00
committed by GitHub
parent d222248d00
commit 05c670e593
95 changed files with 9916 additions and 1312 deletions

View File

@@ -166,7 +166,7 @@ function build_and_install() {
echo -e "${BLUE}[install]${NONE} installing fastdeploy..." echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
cd $DIST_DIR cd $DIST_DIR
${python} -m pip install ./dist/fastdeploy*.whl --force-reinstall --no-cache-dir find . -name "fastdeploy*.whl" | xargs ${python} -m pip install --force-reinstall --no-cache-dir
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
cd .. cd ..
echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed" echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
@@ -228,6 +228,9 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
${BLUE}fastdeploy branch:${NONE} $EFFLLM_BRANCH ($EFFLLM_COMMIT)\n"
echo -e "${GREEN}wheel saved under${NONE} ${RED}${BOLD}./dist${NONE}"
# install wheel
${python} -m pip install ./dist/fastdeploy*.whl
echo -e "${GREEN}wheel install success${NONE}\n"
trap : 0

View File

@@ -0,0 +1,236 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "multi_head_latent_attention_kernel.h"
template <size_t vec_size, typename T>
struct softmax_state_t {
AlignedVector<T, vec_size> o;
T m;
T d;
__device__ __forceinline__ void init() {
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&o) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
}
}
d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = __float2half(-5e4f);
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = __float2bfloat16(-3.38953e38f);
}
}
__device__ __forceinline__ softmax_state_t() {
init();
}
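// merge() folds another partial softmax result (other_o, other_m, other_d)
// into this one using the streaming-softmax recurrence:
//   m_new = max(m, other_m)
//   d_new = d * exp(m - m_new) + other_d * exp(other_m - m_new)
//   o_new = o * exp(m - m_new) + other_o * exp(other_m - m_new)
// so that o_new / d_new equals the softmax-weighted value sum over the
// union of both key ranges.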
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
T other_m,
T other_d) {
// using kType = typename cascade_attn_nv_type2_traits<T>::type;
T m_prev = m, d_prev = d;
m = m_prev > other_m ? m_prev : other_m;
T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m);
d = d_prev * scale1 + other_d * scale2;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = o[i] * scale1 + other_o[i] * scale2;
}
}
__device__ __forceinline__ void normalize() {
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] /= d;
}
}
};
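// Tiled variant used by the decode kernel: the output accumulator is split
// into num_tiles vectors (one per slice of the value head dim handled by a
// thread), while a single running max m and denominator d are shared by all
// tiles; normalize() is therefore invoked once per tile.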
template <size_t vec_size, typename T, uint32_t num_tiles = 0>
struct softmax_state_ts {
uint32_t num_tiles_ = num_tiles;
AlignedVector<T, vec_size> o[num_tiles];
float m;
float d;
__device__ __forceinline__ void init() {
#pragma unroll
for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) {
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&o[tile_id]) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0);
}
}
}
d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = -3.38953e38f;
}
}
__device__ __forceinline__ softmax_state_ts() {
init();
}
__device__ __forceinline__ void normalize(const uint32_t tile_id) {
#pragma unroll
for (size_t i = 0; i < vec_size; i++) {
o[tile_id][i] /= d;
}
}
};
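// produce_kv: stages one tile of the paged KV cache into shared memory.
// The logical sequence position is mapped to a physical cache row through the
// per-request block table (block_id = block_table[pos / BLOCK_SIZE],
// offset = pos % BLOCK_SIZE), and the copy is predicated on
// seq_offset_gmem < chunk_end so out-of-range rows are not fetched.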
template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE, uint32_t CACHE_VEC_SIZE, typename CacheT>
__device__ __forceinline__ void produce_kv(CacheT *smem,
CacheT *kv_base_gptr,
const int * block_table_smem,
const uint32_t seq_offset_gmem,
const uint32_t seq_offset_smem,
const uint32_t kv_head_idx,
const uint32_t kv_num_heads,
const uint32_t tidx,
const uint32_t chunk_start,
const uint32_t chunk_end) {
int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
if (block_id < 0) {
block_id = 0;
}
const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
// 8/16 T/int8 each time
const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
smem + smem_offset_base + vid * CACHE_VEC_SIZE,
kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
seq_offset_gmem < chunk_end
);
}
}
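// compute_qk: for each of the DEAL_EACH_TIME keys staged in shared memory,
// the bdx lanes of a thread group compute partial q.k dot products, reduce
// them with __shfl_xor_sync, and then apply the online-softmax update:
// the running max st.m is raised, the existing accumulators are rescaled by
// exp(m_prev - st.m), and exp(s[j] - st.m) is added to the denominator st.d.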
template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy, uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v, typename T, typename CacheT>
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
const CacheT* k_smem,
const uint32_t kv_idx_base,
const uint32_t stage_idx,
const uint32_t iter_base,
const uint32_t iter_bound,
const uint32_t tidx,
const uint32_t gid,
const float scale,
float *s,
softmax_state_ts<vec_size, T, num_tile_v>& st) {
const CacheT* smem;
AlignedVector<T, vec_size> q_vec;
AlignedVector<T, vec_size> k_vec;
float m_prev = st.m;
// smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM;
smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
#pragma unroll
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
if (iter_base + j < iter_bound) {
if constexpr (std::is_same<T, half>::value) {
s[j] = 0.f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = 0.f;
}
#pragma unroll
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
for (uint32_t i = 0; i < vec_size; ++i) {
s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
}
}
#pragma unroll
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
s[j] += __shfl_xor_sync(0xffffffff, s[j], offset, 32);
}
__syncthreads();
} else {
if constexpr (std::is_same<T, half>::value) {
s[j] = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = -3.38953e38f;
}
}
st.m = st.m > s[j] ? st.m : s[j];
}
// T o_scale = hexp(m_prev - st.m);
float o_scale = __expf(m_prev - st.m);
st.d *= o_scale;
#pragma unroll
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
// s[j] = hexp(s[j] - st.m);
s[j] = __expf(s[j] - st.m);
st.d += s[j];
}
#pragma unroll
for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
for (uint32_t i = 0; i < vec_size; ++i) {
st.o[tile_id][i] *= o_scale;
}
}
}
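// compute_sv: accumulates the softmax-weighted value rows, i.e.
// st.o[tile] += s[j] * V[j] for every in-range key j of this iteration,
// with each lane covering the vector slots it owns (vid = tidx, tidx + bdx, ...).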
template<uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile, typename T, typename CacheT>
__device__ __forceinline__ void compute_sv(const float *s,
const CacheT *base_v_smem,
const uint32_t stage_idx,
const uint32_t iter_base,
const uint32_t iter_bound,
const uint32_t tidx,
softmax_state_ts<vec_size, T, num_tile>& st) {
const CacheT* v_smem;
AlignedVector<T, vec_size> v_vec;
#pragma unroll
for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK;
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
Load<T, vec_size>(v_smem + vid * vec_size, &v_vec);
uint32_t tile_id = vid / bdx;
#pragma unroll
for (int reg_id = 0; reg_id < vec_size; ++reg_id) {
st.o[tile_id][reg_id] += static_cast<T>(s[j]) * v_vec[reg_id];
}
}
}
}

View File

@@ -0,0 +1,560 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decode_attention_func.cuh"
#define CHECK(call) \
do \
{ \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line %d:\n", __LINE__); \
printf(" Error code:%d\n", error_code); \
printf(" Error text:%s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
}while(0)
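// merge_varlen_multi_chunks_v2_kernel: second pass of the split-KV decode
// path. Each (query, head) block reads the per-chunk partial results
// (multi_out, multi_m, multi_d), first reduces the chunks assigned to its
// threadIdx.y row, then combines the bdy row results through shared memory
// with softmax_state_t::merge before normalizing and writing the output token.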
template <typename T, typename OutT, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim]
const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads]
const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_kv,
const int * __restrict__ cum_offsets,
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT * __restrict__ out, // [token_num, num_heads, head_dim]
const float in_scale,
const int num_chunks,
const int chunk_size,
const int max_seq_len,
const int num_heads,
const int head_dim) {
const int vid = threadIdx.x, ty = threadIdx.y;
const int qid = blockIdx.x, hid = blockIdx.y;
const int seq_len_q = seq_lens_q[qid];
if (seq_len_q == 0) return;
int seq_len_kv = seq_lens_kv[qid];
if (seq_len_kv == 0) return;
seq_len_kv += seq_len_q;
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) {
return;
}
__shared__ T smem[bdy * HEAD_DIM];
__shared__ T md_smem[bdy * 2];
const int start_token_ids = qid * max_seq_len - __ldg(&cum_offsets[qid]);
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&res_vec) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
T m;
T d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = __float2half(-5e4f);
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = __float2bfloat16(-3.38953e38f);
}
// merge per ty
#pragma unroll 2
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
T m_prev = m;
T d_prev = d;
const T m_now = multi_m[offset];
const T d_now = multi_d[offset];
m = m_prev > m_now ? m_prev : m_now;
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size;
Load<T, vec_size>(&multi_out[offset], &load_vec);
const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m);
d = d * scale1 + d_now * scale2;
#pragma unroll
for (int j = 0; j < vec_size; j++) {
res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2;
}
}
// store ty res
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
md_smem[2 * ty] = m;
md_smem[2 * ty + 1] = d;
__syncthreads();
// merge bdy
softmax_state_t<vec_size, T> st{};
const uint32_t iter_num = min(num_chunks_this_seq, bdy);
#pragma unroll
for (int i = 0; i < iter_num; i++) {
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
AlignedVector<OutT, vec_size> out_vec;
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
out_vec[i] = static_cast<OutT>(st.o[i]);
}
Store<OutT, vec_size>(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]);
}
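// multi_query_decode_attention_kernel: one block per
// (batch element, KV chunk, KV head). The queries of a GQA group are kept in
// shared memory, KV tiles are prefetched with a NUM_STAGES software pipeline
// (produce_kv / commit_group / wait_group), and attention is accumulated with
// compute_qk + compute_sv. With partition_kv each block writes per-chunk
// partials (tmp_workspace, tmp_m, tmp_d) to be merged later; otherwise it
// normalizes and writes the output directly.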
template <bool partition_kv, typename T, typename OutT, typename CacheT, uint32_t NUM_STAGES, uint32_t DEAL_EACH_TIME, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V,
uint32_t BLOCK_SIZE, uint32_t VEC_SIZE, uint32_t CACHE_VEC_SIZE, uint32_t bdx, uint32_t bdy>
__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim]
CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim]
CacheT * __restrict__ cache_v,
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_kv,
const int * __restrict__ cum_offsets,
const int * __restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
const float scale,
const float in_scale,
const uint32_t chunk_size,
T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim]
T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads]
T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads]
OutT * __restrict__ out) {
const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z;
const uint32_t bid = bidx, gid = threadIdx.y;
const uint32_t tidx = threadIdx.x;
constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE;
constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE;
constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx;
const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid;
const uint32_t kv_num_heads = gridDim.z;
const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
const int *block_table_now = block_table + bid * max_block_num_per_seq;
const uint32_t num_chunks = gridDim.y;
const uint32_t chunk_id = blockIdx.y;
const uint32_t q_len = seq_lens_q[bid];
if (q_len <= 0) {
return;
}
uint32_t kv_len = seq_lens_kv[bid]; // decoded length so far; q_len is added below
if (kv_len <= 0) {
return;
}
kv_len += q_len;
const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
const uint32_t q_start_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const uint32_t q_write_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (chunk_id >= num_chunk_this_seq) {
return;
}
const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0;
const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
const uint32_t chunk_len = chunk_end - chunk_start;
extern __shared__ uint8_t smem[];
const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK;
T *q_smem = reinterpret_cast<T*>(smem); // [HEAD_DIM_QK * sizeof(T)]
T *cu_q_smem = q_smem + gid * HEAD_DIM_QK;
#pragma unroll
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0];
}
__syncthreads();
using VecT = AlignedVector<T, VEC_SIZE>;
VecT q_vec;
#pragma unroll
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
Load<T, VEC_SIZE>(cu_q_smem + vid * VEC_SIZE, &q_vec);
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
q_vec[i] *= scale;
}
Store<T, VEC_SIZE>(q_vec, cu_q_smem + vid * VEC_SIZE);
}
CacheT *kv_smem = reinterpret_cast<CacheT*>(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT));
uint32_t stage_idx = 0;
constexpr int loop_times = DEAL_EACH_TIME / bdy;
#pragma unroll
for (int i = 0; i < NUM_STAGES; ++i) {
#pragma unroll
for (int j = 0; j < loop_times; ++j) {
const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid;
const uint32_t k_seq_id = chunk_start + k_seq_offset;
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
kv_smem,
cache_k,
block_table_now,
k_seq_id,
k_seq_offset,
kv_head_idx,
kv_num_heads,
tidx,
chunk_start,
chunk_end
);
}
commit_group();
stage_idx = (stage_idx + 1) % NUM_STAGES;
}
softmax_state_ts<VEC_SIZE, T, num_tile_v> st;
float s[DEAL_EACH_TIME];
const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME);
for (int iter = 0; iter < num_iters; ++iter) {
wait_group<NUM_STAGES - 1>();
__syncthreads();
// compute qk
compute_qk<VEC_SIZE, num_vec_per_head_qk, bdx, bdy, HEAD_DIM_QK, DEAL_EACH_TIME, num_tile_v>(
cu_q_smem,
kv_smem,
chunk_start + iter * DEAL_EACH_TIME,
stage_idx,
iter * DEAL_EACH_TIME,
chunk_len,
tidx,
gid,
scale,
s,
st
);
__syncthreads();
// compute sv
compute_sv<VEC_SIZE, num_vec_per_head_v, bdx, DEAL_EACH_TIME, HEAD_DIM_QK, num_tile_v>(
s,
kv_smem,
stage_idx,
iter * DEAL_EACH_TIME,
chunk_len,
tidx,
st
);
__syncthreads();
#pragma unroll
for (int j = 0; j < loop_times; ++j) {
const uint32_t k_seq_offset = j * bdy + gid;
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
kv_smem,
cache_k,
block_table_now,
chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME,
stage_idx * DEAL_EACH_TIME + k_seq_offset,
kv_head_idx,
kv_num_heads,
tidx,
chunk_start,
chunk_end
);
}
commit_group();
stage_idx = (stage_idx + 1) % NUM_STAGES;
}
wait_group<0>();
__syncthreads();
// normalize if not partition_kv (or only one chunk); otherwise store per-chunk partials
for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) {
const uint32_t tile_id = vid / bdx;
if (!partition_kv || num_chunk_this_seq == 1) {
st.normalize(tile_id);
}
if (partition_kv && num_chunk_this_seq > 1) {
const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx;
Store<T, VEC_SIZE>(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE);
tmp_m[head_idx] = st.m;
tmp_d[head_idx] = st.d;
} else {
Store<OutT, VEC_SIZE>(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE);
}
}
}
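// Host-side launcher: picks the chunk size from get_max_partition_size
// (FLAGS_cascade_attention_max_partition_size), launches the single-pass
// kernel when the whole KV range fits in one chunk, and otherwise runs the
// split-KV kernel into a temporary [bsz, num_chunks, num_heads, HEAD_DIM_V]
// workspace followed by merge_varlen_multi_chunks_v2_kernel.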
template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME>
void MultiQueryDecoderAttention(
const AppendAttnMetaData& meta_data,
cudaStream_t &stream,
const paddle::Tensor &q,
const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim]
const paddle::Tensor &cache_v, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
const int max_seq_len,
const int max_dec_len,
const float rope_scale,
const float rope_theta,
const float softmax_scale,
const float in_scale,
paddle::Tensor *out) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
auto token_num = meta_data.token_nums;
auto bsz = meta_data.batch_size;
auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
constexpr int num_stages = NUM_STAGE;
constexpr int vec_size = 16 / sizeof(T); // 8 16 32
constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32
constexpr int blockxc = HEAD_DIM_QK / cache_vec_size;
constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size;
constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32;
constexpr int blocky = GROUP_SIZE;
const int gridx = bsz;
constexpr int num_threads = blockx * blocky;
auto splitkv_kernel = multi_query_decode_attention_kernel<true, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V,
BLOCK_SIZE, vec_size, cache_vec_size, blockx, blocky>;
uint32_t cache_smem_bytes = 0;
const T *shift_bias_ptr = shift_bias ? shift_bias.get().data<T>() : nullptr;
const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data<T>() : nullptr;
cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T);
const uint32_t chunk_size = get_max_partition_size(bsz);
const int num_chunks = div_up(max_dec_len, chunk_size);
size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T);
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
const int dev_id = 0;
int sm_count;
int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, splitkv_kernel, num_threads, smem_size);
assert(act_blocks_per_sm > 1);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
const int num_blocks_need = gridx * num_chunks * kv_num_heads;
const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need);
const float ratio = static_cast<float>(num_blocks_need) / static_cast<float>(num_blocks_per_wave);
dim3 grids(gridx, num_chunks, kv_num_heads);
dim3 blocks(blockx, blocky);
if (num_chunks <= 1) {
auto no_splitkv_kernel = multi_query_decode_attention_kernel<false, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, vec_size,
cache_vec_size, blockx, blocky>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
no_splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
softmax_scale,
in_scale,
chunk_size,
nullptr,
nullptr,
nullptr,
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
);
// CHECK(cudaGetLastError());
// CHECK(cudaDeviceSynchronize());
} else {
auto *allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
tmp_workspace = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads));
splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
softmax_scale,
in_scale,
chunk_size,
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
);
// CHECK(cudaGetLastError());
// CHECK(cudaDeviceSynchronize());
constexpr int mblockx = HEAD_DIM_V / vec_size;
constexpr int bdy = 256 / mblockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(mblockx, bdy);
merge_varlen_multi_chunks_v2_kernel<NV_TYPE, NV_TYPE, vec_size, bdy, HEAD_DIM_V><<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cum_offsets.data<int>(),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
in_scale,
num_chunks,
chunk_size,
max_seq_len,
num_heads,
HEAD_DIM_V
);
}
// CHECK(cudaGetLastError());
// CHECK(cudaDeviceSynchronize());
}
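// DecodeMLAAttentionKernel: resolves the runtime shapes (group size, qk/v
// head dims, block size, deal_each_time) to compile-time template parameters
// via the DISPATCH_* macros and forwards to MultiQueryDecoderAttention with
// NUM_STAGE = 2 and cache_bytes = 16.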
template <typename T>
void DecodeMLAAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
const auto head_dim_qk = meta_data.head_dims;
const auto head_dim_v = meta_data.head_dims_v;
const float rope_scale = 0.0;
const float rope_theta = 0.0;
const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
const uint32_t num_stage = get_cascade_attention_num_stages();
const uint32_t num_threads = get_cascade_attention_num_threads();
DISPATCH_CAUSAL(causal, CAUSAL,
{DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
{DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
{DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cum_offsets,
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
}
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);
template void DecodeMLAAttentionKernel<paddle::float16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);

View File

@@ -0,0 +1,291 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mla_cache_kernel.cuh"
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto num_tokens = meta_data.token_nums;
auto block_size = meta_data.block_size;
auto nope_size = meta_data.head_dims_v;
auto all_size = meta_data.head_dims;
int pe_size = all_size - nope_size;
auto kv_num_heads = meta_data.kv_num_heads;
const uint32_t elem_nums = num_tokens * kv_num_heads * all_size;
constexpr int PackSize = 16 / sizeof(DataType_);
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
prefill_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
return {};
}
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cum_offsets.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
padding_offsets,
cum_offsets,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
padding_offsets,
cum_offsets,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}
template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto bsz = meta_data.batch_size;
auto token_num = meta_data.token_nums;
auto block_size = meta_data.block_size;
auto nope_size = meta_data.head_dims_v;
auto all_size = meta_data.head_dims;
int pe_size = all_size - nope_size;
auto kv_num_heads = meta_data.kv_num_heads;
constexpr int PackSize = 16 / sizeof(DataType_);
const int blocksize = 128;
int grid_size = 1;
if (speculate_decoder) {
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
} else {
const uint32_t elem_nums = bsz * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
}
return {};
}
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool speculate_decoder) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cum_offsets.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
padding_offsets,
cum_offsets,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
padding_offsets,
cum_offsets,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}
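// Custom-op registration: both ops update kv_cache in place (the declared
// output kv_cache_out aliases the kv_cache input via SetInplaceMap), which is
// why the kernels above return an empty tensor vector.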
PD_BUILD_OP(prefill_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"seq_lens",
"seq_lens_decoder",
"padding_offsets",
"cum_offsets",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int"})
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
PD_BUILD_OP(decode_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"seq_lens",
"seq_lens_encoder",
"padding_offsets",
"cum_offsets",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int",
"speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));

View File

@@ -0,0 +1,242 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "mem_util.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1>
__global__ void decode_absorb_cache_kernel(
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, nope_size] e.g. 512
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, pe_size] e.g. 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size + pe_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
if (bias < nope_hidden_size) { // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
start_token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
start_token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}
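// The speculative-decoding and prefill variants below share the same cache
// layout; they differ only in how the write position is derived: the
// speculative kernel handles several draft tokens per request
// (write_seq_id = seq_lens[ori_bi] + token offset), while the prefill kernel
// maps each token back to its (batch, position) through padding_offsets and
// offsets it by seq_lens_decoder.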
template <typename T, int VecSize = 1>
__global__ void speculate_decode_absorb_cache_kernel(
const T* __restrict__ kv_nope, // [token_num, kv_num_heads, nope_size] e.g. 512
const T* __restrict__ kv_pe, // [token_num, kv_num_heads, pe_size] e.g. 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size + pe_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_id = linear_index / hidden_size;
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int write_seq_id =
seq_lens[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens[ori_bi],
token_id,
cum_offsets[ori_bi]);
}
if (bias < nope_hidden_size) { // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
token_id * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
token_id * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}
template <typename T, int VecSize = 1>
__global__ void prefill_absorb_cache_kernel(
const T* __restrict__ kv_nope, // [token_num, kv_num_heads, nope_size] e.g. 512
const T* __restrict__ kv_pe, // [token_num, kv_num_heads, pe_size] e.g. 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size + pe_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_decoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const uint32_t token_idx = linear_index / hidden_size;
const uint32_t bias = linear_index % hidden_size;
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
const uint32_t ori_bi = ori_token_idx / max_seq_len;
if (seq_lens[ori_bi] == 0) continue;
const uint32_t ori_seq_id =
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
const uint32_t block_offset = ori_seq_id % block_size;
if (bias < nope_hidden_size) { // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}

View File

@@ -0,0 +1,38 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "utils.cuh"
template <typename T>
void DecodeMLAAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);

View File

@@ -25,6 +25,7 @@ struct AppendAttnMetaData {
int kv_num_heads;
int token_nums;
int head_dims;
int head_dims_v;
int max_blocks_per_seq;
};
@@ -309,10 +310,56 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} \
}
#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
}
#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
case 512: { \
constexpr size_t HEAD_DIM = 512; \
__VA_ARGS__ \
break; \
} \
case 576: { \
constexpr size_t HEAD_DIM = 576; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
}
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
if (num_stage == 2) { \
constexpr size_t NUM_STAGE = 2; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the num_stage: ", num_stage); \
}
#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \
@@ -328,8 +375,11 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \
constexpr size_t cache_bytes = 4; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the cache_type: ", cache_type); \
}
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
if (deal_each_time == 32) { \
constexpr size_t DEAL_EACH_TIME = 32; \
@@ -387,6 +437,20 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
PD_THROW("not support the group_size", group_size); \
}
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 128) { \
constexpr size_t GROUP_SIZE = 128; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
}
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
if (block_shape_q <= 16) { \
constexpr size_t BLOCK_SHAPE_Q = 16; \

View File

@@ -316,6 +316,96 @@ void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
int64_t num_experts);
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch);
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool speculate_decoder);
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len);
void FusedRotaryPositionEncoding(
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
// [num_tokens, num_heads * head_size]
paddle::Tensor& key,
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
// head_size]
const paddle::Tensor& position_ids, // [num_tokens]
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
int head_size,
bool is_neox);
std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& query_bias,
const paddle::optional<paddle::Tensor>& query_out_scales,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const int nope_size,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder);
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M);
@@ -370,6 +460,14 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out,
paddle::Tensor const &input,
paddle::Tensor &scales, float scale_ub);
std::vector<paddle::Tensor> NoauxTc(
paddle::Tensor& scores,
paddle::Tensor& scores_with_bias,
int n_group,
int topk_group,
int topk,
float routed_scaling_factor);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -627,6 +725,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("use_atomic_add"), py::arg("use_atomic_add"),
py::arg("use_fp32_reduce"), py::arg("use_fp32_reduce"),
py::arg("is_zp_float")); py::arg("is_zp_float"));
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
"get_position_ids_and_mask_encoder_batch function");
/**
@@ -653,4 +753,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant, m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
"dynamic_per_token_scaled_fp8_quant function", "dynamic_per_token_scaled_fp8_quant function",
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub")); py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
m.def("fused_rotary_position_encoding", &FusedRotaryPositionEncoding, "fused_rotary_position_encoding function");
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
}

custom_ops/gpu_ops/env.h (new file, 64 lines)
View File

@@ -0,0 +1,64 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <cstdlib>
#include <string>
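// Runtime tuning knobs, each read once from an environment variable with a
// built-in fallback:
//   FLAGS_dec_block_shape_q                     (default 16)
//   FLAGS_enc_block_shape_q                     (default 64)
//   FLAGS_cascade_attention_max_partition_size  (default 32768)
//   FLAGS_cascade_attention_deal_each_time      (default 32)
//   FLAGS_cascade_attention_num_stages          (default 2)
//   FLAGS_cascade_attention_num_threads         (default 128)
//   FLAGS_mla_use_tensorcore                    (default 1, i.e. enabled)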
inline uint32_t get_decoder_block_shape_q() {
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
static const uint32_t decoder_block_shape_q =
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
return decoder_block_shape_q;
}
inline uint32_t get_encoder_block_shape_q() {
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
static const uint32_t encoder_block_shape_q =
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
return encoder_block_shape_q;
}
inline uint32_t get_max_partition_size(int bsz) {
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
static const uint32_t max_partition_size =
max_partition_size_env == nullptr ? 32768 : std::stoul(std::string(max_partition_size_env));
return max_partition_size;
}
inline uint32_t get_cascade_attention_deal_each_time() {
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
static const uint32_t cascade_attention_deal_each_time =
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
}
inline uint32_t get_cascade_attention_num_stages() {
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
static const uint32_t cascade_attention_num_stages =
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
}
inline uint32_t get_cascade_attention_num_threads() {
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
static const uint32_t cascade_attention_num_threads =
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
}
inline bool get_mla_use_tensorcore() {
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
static const uint32_t mla_use_tensorcore =
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
return mla_use_tensorcore != 0 ? true : false;
}
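The getters above share one pattern: read a FLAGS_* environment variable once, cache the result in a function-local static, and fall back to a built-in default when the variable is unset or zero. A minimal usage sketch follows; the calling function is hypothetical and not part of this change.

#include <cstdio>
#include "env.h"  // the header added above

// Prints the tuning knobs the attention kernels would pick up for this process.
void dump_attention_tuning(int bsz) {
  std::printf("dec_block_shape_q=%u enc_block_shape_q=%u\n",
              get_decoder_block_shape_q(),   // FLAGS_dec_block_shape_q, default 16
              get_encoder_block_shape_q());  // FLAGS_enc_block_shape_q, default 64
  std::printf("max_partition_size=%u deal_each_time=%u num_stages=%u num_threads=%u\n",
              get_max_partition_size(bsz),   // FLAGS_cascade_attention_max_partition_size, default 32768
              get_cascade_attention_deal_each_time(),
              get_cascade_attention_num_stages(),
              get_cascade_attention_num_threads());
}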


@@ -0,0 +1,146 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <typename T, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding_kernel(
T* __restrict__ arr,
const T* __restrict__ cos_ptr,
const T* __restrict__ sin_ptr,
int rot_offset,
int embed_dim) {
int x_index, y_index;
T cos, sin;
if (IS_NEOX) {
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = cos_ptr[x_index];
sin = sin_ptr[x_index];
} else {
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = cos_ptr[x_index / 2];
sin = sin_ptr[x_index / 2];
}
const T x = arr[x_index];
const T y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}
template <typename T, bool IS_NEOX>
__global__ void apply_rotary_embedding_kernel(
T* __restrict__ query, // [num_tokens, num_heads, head_size]
T* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
const int* __restrict__ position_ids, // [num_tokens]
const T* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int pos = position_ids[token_idx];
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2;
const T* cos_ptr = cache_ptr;
const T* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
}
void FusedRotaryPositionEncoding(
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
// [num_tokens, num_heads * head_size]
paddle::Tensor& key,
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
// head_size]
const paddle::Tensor& position_ids, // [num_tokens]
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
int head_size,
bool is_neox) {
int64_t num_tokens = query.dims()[0];
int num_heads = query.numel() / num_tokens / head_size;
int num_kv_heads = key.numel() / num_tokens / head_size;
int rot_dim = cos_sin_cache.dims()[1];
int64_t query_stride = num_heads * head_size;
int64_t key_stride = num_kv_heads * head_size;
if (num_tokens > 65535) {
PD_THROW(
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
}
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
query.dtype(), "apply_rotary_embedding_kernel", [&] {
if (is_neox) {
apply_rotary_embedding_kernel<data_t, true>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
key.data<data_t>(),
position_ids.data<int>(),
cos_sin_cache.data<data_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
apply_rotary_embedding_kernel<data_t, false>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
key.data<data_t>(),
position_ids.data<int>(),
cos_sin_cache.data<data_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}
PD_BUILD_OP(fused_rotary_position_encoding)
.Inputs({"query", "key", "position_ids", "cos_sin_cache"})
.Outputs({"query_out", "key_out"})
.Attrs({"head_size: int", "is_neox: bool"})
.SetInplaceMap({{"query", "query_out"}, {"key", "key_out"}})
.SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding));
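For reference, the elementwise update performed by apply_token_rotary_embedding_kernel is the standard rotary rotation; with d = rot_dim, j indexing the d/2 frequencies, and theta_j read from the cos/sin halves of cos_sin_cache, each pair is rewritten in place as

\[
\begin{aligned}
x' &= x\cos\theta_j - y\sin\theta_j,\\
y' &= y\cos\theta_j + x\sin\theta_j,
\end{aligned}
\qquad
(x,\ y) =
\begin{cases}
(a_j,\ a_{j+d/2}) & \text{NEOX layout } (\texttt{is\_neox}),\\
(a_{2j},\ a_{2j+1}) & \text{interleaved layout}.
\end{cases}
\]

This is only a restatement of the kernel above for readability; the same rotation is applied to both query and key heads.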


@@ -0,0 +1,86 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
__global__ void GetPositionIdsAndMaskEncoderBatchKernel(
const int* seq_lens_encoder, // [bsz] encoder length of each batch
const int* seq_lens_decoder, // [bsz] decoder length of each batch
const int* seq_lens_this_time,
int* position_ids, // flattened output position_ids
int* mask_encoder_batch,
const int bsz) { // batch size
// Current thread index (one thread handles one batch).
int tid = threadIdx.x;
if (tid >= bsz) return;
// Compute this batch's write offset on the fly.
int offset = 0;
for (int i = 0; i < tid; i++) {
offset += seq_lens_encoder[i];
if (seq_lens_decoder[i] > 0) {
offset += seq_lens_this_time[i];
}
}
// encoder and decoder lengths of this batch
int encoder_len = seq_lens_encoder[tid];
int decoder_len = seq_lens_decoder[tid];
int seq_len_this_time = seq_lens_this_time[tid];
// write position_ids for the encoder part
for (int i = 0; i < encoder_len; i++) {
position_ids[offset + i] = i;
mask_encoder_batch[offset + i] = 1;
}
offset += encoder_len;
// write position_ids for the decoder part
if (decoder_len > 0) {
for (int i = 0; i < seq_len_this_time; i++) {
position_ids[offset + i] = decoder_len + i; // continue from the decoder (context) length
mask_encoder_batch[offset + i] = 0;
}
}
}
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch) {
const int bsz = seq_lens_this_time.shape()[0];
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
const_cast<int*>(position_ids.data<int>()),
const_cast<int*>(mask_encoder_batch.data<int>()),
bsz);
}
PD_BUILD_OP(get_position_ids_and_mask_encoder_batch)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"position_ids",
"mask_encoder_batch"})
.Outputs({"position_ids_out", "mask_encoder_batch_out"})
.SetInplaceMap({{"position_ids", "position_ids_out"},
{"mask_encoder_batch", "mask_encoder_batch_out"}})
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
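A small worked example of the layout this kernel produces, with made-up lengths: batch 0 is a prefill of 3 tokens and batch 1 is a decode step whose context already holds 7 tokens. The host-side sketch below reproduces the same indexing on the CPU (illustrative only, not part of this change).

#include <cstdio>
#include <vector>

int main() {
  const int bsz = 2;
  const int seq_lens_encoder[]   = {3, 0};   // prefill lengths
  const int seq_lens_decoder[]   = {0, 7};   // already-cached context lengths
  const int seq_lens_this_time[] = {3, 1};   // tokens processed this step
  std::vector<int> position_ids, mask_encoder_batch;

  for (int b = 0; b < bsz; ++b) {
    for (int i = 0; i < seq_lens_encoder[b]; ++i) {      // encoder part: positions 0..len-1
      position_ids.push_back(i);
      mask_encoder_batch.push_back(1);
    }
    if (seq_lens_decoder[b] > 0) {                       // decoder part: continue from context length
      for (int i = 0; i < seq_lens_this_time[b]; ++i) {
        position_ids.push_back(seq_lens_decoder[b] + i);
        mask_encoder_batch.push_back(0);
      }
    }
  }
  // Expected: position_ids = {0, 1, 2, 7}, mask_encoder_batch = {1, 1, 1, 0}
  for (size_t i = 0; i < position_ids.size(); ++i)
    std::printf("(%d, %d) ", position_ids[i], mask_encoder_batch[i]);
  std::printf("\n");
}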


@@ -39,10 +39,12 @@ namespace cub = hipcub;
#include <fstream>
#include <iostream>
#include "env.h"
#include "paddle/extension.h" #include "paddle/extension.h"
#include "paddle/phi/core/allocator.h" #include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/cuda_stream.h" #include "paddle/phi/core/cuda_stream.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -513,3 +515,10 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
return max_shared_mem_per_block_opt_in;
}
inline int GetSMVersion() {
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
phi::backends::gpu::GetCurrentDeviceId());
return sm_version;
}


@@ -0,0 +1,255 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#include <cute/tensor.hpp>
#include <cutlass/detail/helper_macros.hpp>
#include "utils.cuh"
namespace mla_attn {
using namespace cute;
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); }
};
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
};
template <int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template <typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator& op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template <>
struct Allreduce<2> {
template <typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator& op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& summary, Operator& op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0>& dst,
Tensor<Engine1, Layout1>& src, Operator& op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++) {
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& summary, Operator& op) {
thread_reduce_<init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& max) {
MaxOp<float> max_op;
reduce_<init>(tensor, max, max_op);
}
template <bool init, bool warp_reduce = true, typename Engine0, typename Layout0, typename Engine1,
typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& sum) {
SumOp<float> sum_op;
thread_reduce_<init>(tensor, sum, sum_op);
if constexpr (warp_reduce) {
quad_allreduce_(sum, sum, sum_op);
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void apply_exp2(Tensor<Engine0, Layout0>& tensor,
Tensor<Engine1, Layout1> const& max) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
auto row_max = max(mi);
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = __expf(tensor(mi, ni) - row_max);
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tensor,
Tensor<Engine1, Layout1> const& max,
const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
auto row_max = max(mi);
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// row_max * scale is a constant for each row, so we can use fma here
tensor(mi, ni) = __expf(tensor(mi, ni) * scale - row_max * scale);
}
}
}
template <int NUM_ROWS_PER_THREAD, bool WITH_SCALE>
struct OnlineSoftmax {
constexpr static float fill_value = -5e4;
using TensorT = decltype(make_tensor<float>(Shape<Int<NUM_ROWS_PER_THREAD>>{}));
TensorT row_max, row_sum, scores_scale;
float sm_scale_log2;
CUTLASS_DEVICE OnlineSoftmax(float sm_scale_log2) : sm_scale_log2(sm_scale_log2) {
clear(scores_scale);
};
__forceinline__ __device__ TensorT get_lse() const { return row_sum; }
template <bool init, typename Tensor0>
__forceinline__ __device__ TensorT update(Tensor0& acc_s) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
if constexpr (init) {
reduce_max</*init=*/true>(scores, row_max);
if constexpr (WITH_SCALE) {
scale_apply_exp2(scores, row_max, sm_scale_log2);
} else {
apply_exp2(scores, row_max);
}
reduce_sum</*init=*/true, /*warp_reduce=*/false>(scores, row_sum);
} else {
// update row_max
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
reduce_max</*init=*/false>(scores, row_max);
// update scores_scale and scale row_sum
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = row_max(mi);
if constexpr (WITH_SCALE) {
scores_scale(mi) = __expf((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2);
} else {
scores_scale(mi) = __expf(scores_max_prev(mi) - scores_max_cur);
}
row_sum(mi) *= scores_scale(mi);
}
// perform exp2 on scores
if constexpr (WITH_SCALE) {
scale_apply_exp2(scores, row_max, sm_scale_log2);
} else {
apply_exp2(scores, row_max);
}
// update row_sum
reduce_sum</*init=*/false, /*warp_reduce=*/false>(scores, row_sum);
return scores_scale;
}
};
template <typename Tensor0>
__forceinline__ __device__ TensorT finalize(Tensor0& acc_s) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float sum = row_sum(mi);
float inv_sum = 1.f / sum;
scores_scale(mi) = inv_sum;
row_max(mi) *= sm_scale_log2;
}
return scores_scale;
};
template <typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor1& acc_o) {
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale(mi);
}
}
};
template <typename Tensor1, typename Tensor2>
__forceinline__ __device__ void rescale_o(Tensor1& acc_o, Tensor2& scores_scale_input) {
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale_input(mi);
}
}
};
};
} // namespace mla_attn
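The update/rescale_o/finalize methods above implement the usual streaming-softmax recurrence, restated here for reference (when WITH_SCALE is set, the scores are first multiplied by sm_scale_log2, so the scale folds into the exponents):

\[
m^{(t)} = \max\bigl(m^{(t-1)},\ \max_j s^{(t)}_j\bigr),\qquad
c^{(t)} = e^{\,m^{(t-1)} - m^{(t)}},
\]
\[
d^{(t)} = c^{(t)} d^{(t-1)} + \sum_j e^{\,s^{(t)}_j - m^{(t)}},\qquad
o^{(t)} = c^{(t)} o^{(t-1)} + \widetilde{P}^{(t)} V^{(t)},\qquad
o_{\mathrm{final}} = o^{(T)} / d^{(T)},
\]

where c^(t) is the scores_scale returned by update and applied by rescale_o, d is row_sum, and the final division is what finalize hands back as its last scores_scale.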


@@ -0,0 +1,235 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include <cuda_device_runtime_api.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <type_traits>
#include <vector>
#include "cute/tensor.hpp"
#include "mla_hopper.cuh"
#include <iostream>
#include <string>
#include <sstream>
#include "batch_mla_with_paged_kv_cache.h"
#include "env.h"
using namespace cute;
using namespace mla_attn;
using namespace std;
template <typename T>
struct cascade_type_traits {
using type = T;
using cutlass_type = T;
};
template <>
struct cascade_type_traits<phi::dtype::bfloat16> {
using type = __nv_bfloat16;
using cutlass_type = cutlass::bfloat16_t;
};
template <>
struct cascade_type_traits<phi::dtype::float16> {
using type = half;
using cutlass_type = cutlass::half_t;
};
template <>
struct cascade_type_traits<phi::dtype::float8_e4m3fn> {
using type = __nv_fp8_e4m3;
using cutlass_type = cutlass::float_e4m3_t;
};
template <typename T>
void BatchMLAWithPagedKVCacheKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out) {
using NV_TYPE = typename cascade_type_traits<T>::type;
using CUTLASS_TYPE = typename cascade_type_traits<T>::cutlass_type;
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
const auto q_head_num = meta_data.q_num_heads;
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
const auto max_block_num = bsz * max_block_num_per_seq;
const uint32_t chunk_size = get_max_partition_size(bsz);
int q_head_dim = meta_data.head_dims;
int k_head_dim = meta_data.head_dims;
int v_head_dim = meta_data.head_dims_v;
// int num_chunks = max_dec_len / chunk_size;
int num_chunks = div_up(max_dec_len, chunk_size);
auto *allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
O_tmp = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num * v_head_dim));
m_tmp = allocator->Allocate(
sizeof(float) *
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
d_tmp = allocator->Allocate(
sizeof(float) *
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
Params<CUTLASS_TYPE, CUTLASS_TYPE, CUTLASS_TYPE, int> params = {};
params.Q = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(q.data<T>()));
params.KV = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(latent_cache.data<T>()));
params.O = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(out->data<T>()));
params.O_tmp = reinterpret_cast<CUTLASS_TYPE*>(O_tmp->ptr());
params.m = reinterpret_cast<float*>(m_tmp->ptr());
params.d = reinterpret_cast<float*>(d_tmp->ptr());
params.block_tables = const_cast<int*>(block_tables.data<int>());
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
params.padding_offsets = const_cast<int*>(padding_offsets.data<int>());
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
params.num_blocks_x_int = num_blocks_x;
params.q_stride_bsz = q_head_num * q_head_dim;
params.q_stride_head_num = q_head_dim;
params.kv_stride_block_num = block_size * k_head_dim;
params.kv_stride_block_size = k_head_dim;
params.o_stride_bsz = q_head_num * v_head_dim;
params.o_stride_head_num = v_head_dim;
params.bsz = bsz;
params.token_num = token_num;
params.max_seq_len = max_seq_len;
params.max_block_num = max_block_num;
params.max_block_num_per_seq = max_block_num_per_seq;
params.q_num_head = q_head_num;
params.qk_head_dim = q_head_dim;
params.vo_head_dim = v_head_dim;
params.block_size = block_size;
params.max_draft_token_num = draft_token_num;
params.sm_scale = softmax_scale;
params.chunk_size = chunk_size;
params.chunk_num = num_chunks;
if (q_head_dim == 576) {
BatchMLAWithPagedKVCacheDispatched<576, 512, NV_TYPE>(
params, stream
);
} else {
PD_THROW("error!!! q_head_dim must be 576 !!!\n");
}
}
template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out);
template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out);
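In the split-KV ("chunked") decode path above, each KV chunk produces a partial output plus per-row (m, d) softmax statistics, which a later merge pass combines into the final O. A host-side sizing sketch with illustrative numbers follows; the constants are assumptions for the example, not values from this PR.

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative values only.
  const int max_dec_len = 8192, chunk_size = 32768;   // chunk_size comes from get_max_partition_size(bsz)
  const int bsz = 4, draft_token_num = 2, q_head_num = 16, v_head_dim = 512;
  const int num_chunks = (max_dec_len + chunk_size - 1) / chunk_size;  // div_up, as in the kernel
  const size_t o_tmp_elems = static_cast<size_t>(num_chunks) * bsz * draft_token_num * q_head_num * v_head_dim;
  const size_t md_elems    = static_cast<size_t>(num_chunks) * bsz * draft_token_num * q_head_num;
  std::printf("num_chunks=%d  O_tmp=%zu elems  m/d=%zu floats each\n",
              num_chunks, o_tmp_elems, md_elems);
  return 0;
}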


@@ -0,0 +1,69 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "paddle/extension.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/allocator.h"
#include "append_attn/utils.cuh"
template <typename T>
void BatchMLAWithPagedKVCacheKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out);


@@ -0,0 +1,175 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_EPILOGUE_CUH_
#define ATTENTION_HOPPER_EPILOGUE_CUH_
#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "named_barrier.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA
namespace mla_attn {
using namespace cute;
template <typename Ktraits>
struct CollectiveEpilogue {
using DTypeO = typename Ktraits::DTypeO;
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
static constexpr int NUM_WARPS = Ktraits::NUM_WARPS;
static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp;
static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
decltype(cute::get<1>(TileShape_PDV{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeO>;
using SharedStorage = cute::array_aligned<DTypeO, cute::cosize_v<SmemLayoutO>>;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using ShapeTmpT = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
using StrideTmpT = cute::Shape<int32_t, _1, int32_t, int32_t>;
using LayoutTmpT = cute::Layout<ShapeTmpT, StrideTmpT>;
using ShapeNTMAT = cute::Shape<int32_t, int32_t>;
using StrideNTMAT = cute::Shape<int32_t, _1>;
using LayoutNTMAT = cute::Layout<ShapeNTMAT, StrideNTMAT>;
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
using TMA_O = decltype(make_tma_copy(
GmemTiledCopyOTMA{},
make_tensor(make_gmem_ptr(static_cast<DTypeO*>(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{},
select<0, 1>(TileShape_PDV{}), _1{})); // no mcast for O
static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v<DTypeO>); // 8
static_assert(HEAD_DIM_VO % VEC_SIZE == 0);
static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM_VO / VEC_SIZE; // 64
static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0);
static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW;
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, DTypeO>;
using TiledCopyOThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<NUM_ROWS>{}, Int<NUM_THREADS_PER_ROW>{}), LayoutRight{}));
using TiledCopyOValLayout =
decltype(cute::make_layout(cute::make_shape(_1{}, Int<VEC_SIZE>{}), LayoutRight{}));
using TiledCopyO =
decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{}, // Thr layout
TiledCopyOValLayout{} // Val layout
));
struct Arguments {
DTypeO* O_ptr;
LayoutNTMAT const layout_O;
DTypeO* O_ptr_tmp;
LayoutNTMAT const layout_O_tmp;
};
// Device side kernel params
struct Params {
DTypeO* O_ptr;
LayoutNTMAT const layout_O;
DTypeO* O_ptr_tmp;
LayoutNTMAT const layout_O_tmp;
};
static Params to_underlying_arguments_ntma(Arguments const& args) {
return {args.O_ptr, args.layout_O, args.O_ptr_tmp, args.layout_O_tmp};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& epilogue_params) {}
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE,
typename TiledMma>
CUTLASS_DEVICE void store(Params const& epilogue_params,
FrgTensorO const& tOrO,
FrgTensorLSE const& lse,
SharedStorage& shared_storage,
TiledMma tiled_mma,
const int thread_idx,
const int bid,
const int bsz,
const int seq_len_now,
const int start_token_idx,
const int tile_idx,
const int kv_len,
const int chunk_size,
const int max_draft_token_num,
const int o_stride_bsz) {
const int num_chunks = cute::ceil_div(kv_len, chunk_size);
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOrO_out = convert_type<DTypeO>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// make sure gemm done
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
// r2s
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
// make sure r2s done
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
TiledCopyO gmem_tiled_copy_O;
auto O_ptr = num_chunks == 1 ? epilogue_params.O_ptr + start_token_idx * o_stride_bsz : epilogue_params.O_ptr_tmp + (tile_idx * bsz + bid) * max_draft_token_num * o_stride_bsz;
Tensor mO = make_tensor(make_gmem_ptr(O_ptr), epilogue_params.layout_O);
Tensor gO = local_tile(mO, select<0, 1>(TileShape_PDV{}), make_coord(_, _0{}))(_, _, _0{});
Tensor cO = make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx)
ThrCopy thr_copy_O = gmem_tiled_copy_O.get_slice(thread_idx);
Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D)
Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY, CPY_O, CPY_D)
Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D)
Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D))
Tensor tOsOGroup = flatten_1(tOsO); // (CPY, (CPY_O, CPY_D))
Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D))
// copy if not out of bound
auto predicate_fn = [&](auto coords) {
auto s_coords = tOcOGroup(_0{}, coords);
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, seq_len_now);
};
copy_if(gmem_tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
}
};
} // namespace mla_attn
#endif // ATTENTION_HOPPER_EPILOGUE_CUH_


@@ -0,0 +1,163 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
#define ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
#include <type_traits>
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
namespace mla_attn {
using namespace cute;
template <typename MainloopPipeline, typename MainloopPipelineQ, class DTypeQ, class DTypeKV, class DTypeQKAccum, class DTypeOut, class IdType,
int BLOCK_SHAPE_KV, class SmemLayoutQ, class SmemLayoutK, class SmemLayoutP, class SmemLayoutRow, class SmemLayoutO>
struct alignas(16) SharedStorageQKVO {
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutQ>> smem_q;
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutP>> smem_p;
alignas(16) cute::array_aligned<DTypeQKAccum, cute::cosize_v<SmemLayoutRow>> smem_scale;
union {
alignas(16) cute::array_aligned<DTypeKV, cute::cosize_v<SmemLayoutK>> smem_kv;
alignas(16) cute::array_aligned<DTypeOut, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct {
alignas(16) typename MainloopPipelineQ::SharedStorage pipeline_q;
alignas(16) typename MainloopPipeline::SharedStorage pipeline_kv;
};
};
template <bool USE_TMA_LOAD_KV_, int HEAD_DIM_QK_, int HEAD_DIM_VO_, int GROUP_SIZE_, int BLOCK_SHAPE_Q_, int BLOCK_SHAPE_KV_,
int NUM_STAGES_, typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_, typename NV_TYPE_>
struct AttentionKernelTraits {
using DTypeQ = DTypeQ_;
using DTypeKV = DTypeKV_;
using DTypeO = DTypeO_;
using IdType = IdType_;
using DTypeQKAccum = float;
using DTypePVAccum = float;
using NV_TYPE = NV_TYPE_;
static constexpr bool USE_TMA_LOAD_KV = USE_TMA_LOAD_KV_;
static constexpr int GROUP_SIZE = GROUP_SIZE_;
static constexpr int BLOCK_SHAPE_Q = BLOCK_SHAPE_Q_;
static_assert(BLOCK_SHAPE_Q % 64 == 0);
static constexpr int BLOCK_SHAPE_KV = BLOCK_SHAPE_KV_;
static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_;
static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_;
static constexpr int NUM_PER_STAGE = BLOCK_SHAPE_KV * HEAD_DIM_QK;
static_assert(HEAD_DIM_QK % 32 == 0);
static_assert(HEAD_DIM_VO % 32 == 0);
static constexpr int NUM_WARPS = 12;
static constexpr int NUM_THREADS = 384;
static constexpr int NUM_PRODUCER_THREADS = 128;
using TileShape_QKD = Shape<Int<BLOCK_SHAPE_Q>, Int<BLOCK_SHAPE_KV>, Int<HEAD_DIM_QK>>;
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
static constexpr int NUM_STAGES = NUM_STAGES_;
using AtomLayoutQKD = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _1, _1>>;
using AtomLayoutPV = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _2, _1>>;
using TiledMmaQK = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<DTypeQ, DTypeKV, DTypeQKAccum, TileShape_QKD>(), AtomLayoutQKD{}));
using TiledMmaPV = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutPV{}));
using TiledMmaPVSS = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutPV{}));
static constexpr int NUM_MMA_THREADS = size(TiledMmaPV{});
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
decltype(cute::get<2>(TileShape_QKD{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})),
decltype(cute::get<2>(TileShape_QKD{}))>());
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int<NUM_STAGES>{})));
using SmemLayoutVt = decltype(composition(
SmemLayoutK{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}),
get<1>(TileShape_QKD{}), Int<NUM_STAGES>{}),
Step<_2, _1, _3>{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})),
decltype(cute::get<1>(TileShape_PDV{}))>());
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int<1>{})));
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVtOneStage = decltype(composition(
SmemLayoutV{}, make_ordered_layout(make_shape(get<1>(TileShape_PDV{}),
get<2>(TileShape_PDV{}), Int<1>{}),
Step<_2, _1, _3>{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
decltype(cute::get<1>(TileShape_PDV{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
using SmemCopyAtom = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeQ>;
static constexpr bool IS_CTA_32 = (BLOCK_SHAPE_KV == 32);
using SmemLayoutRowOneStage = Layout<Shape<_2, Int<128>>, Stride<_1, _2>>;
using SmemLayoutRowTwoStage = Layout<Shape<_2, Int<128>, _2>, Stride<_1, _2, _256>>;
using SmemLayoutRow = std::conditional_t<IS_CTA_32, SmemLayoutRowTwoStage, SmemLayoutRowOneStage>;
using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
decltype(cute::get<1>(TileShape_QKD{}))>());
using SmemLayoutPSSOneStage = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_QKD{})));
using SmemLayoutPSSTwoStage = decltype(tile_to_shape(SmemLayoutAtomP{}, make_shape(Int<BLOCK_SHAPE_Q>{}, Int<BLOCK_SHAPE_KV>{}, Int<2>{})));
using SmemLayoutP = std::conditional_t<IS_CTA_32, SmemLayoutPSSTwoStage, SmemLayoutPSSOneStage>;
using MainloopPipelineQ = typename cutlass::PipelineAsync<1>;
using PipelineStateQ = typename cutlass::PipelineState<1>;
using MainloopPipeline =
std::conditional_t<USE_TMA_LOAD_KV, typename cutlass::PipelineTmaAsync<NUM_STAGES>,
typename cutlass::PipelineAsync<NUM_STAGES>>;
using PipelineState = typename cutlass::PipelineState<NUM_STAGES>;
using SharedStorage = SharedStorageQKVO<MainloopPipeline, MainloopPipelineQ, DTypeQ, DTypeKV, DTypeQKAccum, DTypeO, IdType, BLOCK_SHAPE_KV,
SmemLayoutQ, SmemLayoutK, SmemLayoutP, SmemLayoutRow, SmemLayoutO>;
};
} // namespace mla_attn
#endif


@@ -0,0 +1,348 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
#define ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "named_barrier.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA
namespace mla_attn {
using namespace cute;
template <typename Ktraits, bool CAUSAL>
struct CollectiveMainloop {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeMD = float;
using IdType = typename Ktraits::IdType;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
using TileShape_PDV = typename Ktraits::TileShape_PDV;
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
static constexpr int NUM_STAGES = Ktraits::NUM_STAGES;
static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK;
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(DTypeQ); // 8
static_assert(HEAD_DIM_QK % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // 576 512
static constexpr int kGmemThreadsPerRow = 64 / kGmemElemsPerLoad; // 8
using AlignmentTypeQ = cute::uint_byte_t<static_cast<int>(sizeof(DTypeQ)) * kGmemElemsPerLoad>;
using GmemCopyAtomQ = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<AlignmentTypeQ>, DTypeQ>;
static constexpr int kNThreadsLoad = Ktraits::NUM_PRODUCER_THREADS;
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // 32, 8
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopy = decltype(make_tiled_copy(
GmemCopyAtomQ{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemLayoutAtomQ = Layout<
Shape<Int<Ktraits::NUM_PRODUCER_THREADS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // 32, 8
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopyQ = decltype(make_tiled_copy(
GmemCopyAtomQ{},
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
using SmemLayoutAtomQ = typename Ktraits::SmemLayoutAtomQ;
using SmemLayoutK = typename Ktraits::SmemLayoutK;
using SmemLayoutV = typename Ktraits::SmemLayoutV;
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
using ShapeQT = cute::Shape<int32_t, int32_t>;
using StrideQT = cute::Shape<int32_t, _1>;
using LayoutQT = cute::Layout<ShapeQT, StrideQT>;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using ShapeMDT = cute::Shape<int32_t, int32_t>;
using StrideMDT = cute::Shape<int32_t, _1>;
using LayoutMDT = cute::Layout<ShapeMDT, StrideMDT>;
using TMA_KV = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(
make_gmem_ptr(static_cast<DTypeKV const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)), StrideT{}
),
take<0, 2>(SmemLayoutK{}),
select<1, 2>(TileShape_QKD{}),
_1{})); // no mcast for KV
static constexpr bool USE_TMA_LOAD_KV = Ktraits::USE_TMA_LOAD_KV;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using MainloopPipelineQ = typename Ktraits::MainloopPipelineQ;
using PipelineParamsQ = typename MainloopPipelineQ::Params;
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
static constexpr uint32_t TmaTransactionBytesQ =
static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<DTypeQ> / 8);
static constexpr uint32_t TmaTransactionBytesKV =
static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<DTypeKV> / 8);
// Host side kernel arguments
struct Arguments {
LayoutQT layout_Q;
LayoutT layout_KV;
LayoutMDT layout_MD;
DTypeQ const* Q_ptr;
DTypeKV const* KV_ptr;
DTypeMD const* m_ptr;
DTypeMD const* d_ptr;
IdType const* kv_block_tables;
IdType const* seq_lens_this_time;
IdType const* seq_lens_encoder;
IdType const* seq_lens_decoder;
IdType const* cumsum_q_seqlens;
IdType const* batch_ids;
IdType const* tile_ids_per_batch;
IdType const* num_blocks_x;
float sm_scale;
int bsz;
int max_block_num;
int max_block_num_per_seq;
int q_stride_bsz;
int q_stride_head_num;
int kv_stride_block_num;
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
int chunk_num;
int max_draft_token_num;
};
// Device side kernel params
struct Params {
LayoutQT layout_Q;
LayoutT layout_KV;
LayoutMDT layout_MD;
DTypeQ *Q_ptr;
DTypeKV* KV_ptr;
DTypeMD* m_ptr;
DTypeMD* d_ptr;
IdType* kv_block_tables;
IdType* seq_lens_this_time;
IdType* seq_lens_encoder;
IdType* seq_lens_decoder;
IdType* cumsum_q_seqlens;
IdType* batch_ids;
IdType* tile_ids_per_batch;
IdType* num_blocks_x;
float sm_scale;
int bsz;
int max_block_num;
int max_block_num_per_seq;
int q_stride_bsz;
int q_stride_head_num;
int kv_stride_block_num;
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
int chunk_num;
int max_draft_token_num;
TMA_KV tma_load_KV;
};
static Params to_underlying_arguments(Arguments const& args) {
TMA_KV tma_load_KV;
if constexpr (USE_TMA_LOAD_KV) {
Tensor mKV = make_tensor(make_gmem_ptr(args.KV_ptr), args.layout_KV);
tma_load_KV =
make_tma_copy(GmemTiledCopyKV{}, mKV, SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_QKD{}), _1{});
}
return {args.layout_Q,
args.layout_KV,
args.layout_MD,
const_cast<DTypeQ*>(args.Q_ptr),
const_cast<DTypeKV*>(args.KV_ptr),
const_cast<DTypeMD*>(args.m_ptr),
const_cast<DTypeMD*>(args.d_ptr),
const_cast<IdType*>(args.kv_block_tables),
const_cast<IdType*>(args.seq_lens_this_time),
const_cast<IdType*>(args.seq_lens_encoder),
const_cast<IdType*>(args.seq_lens_decoder),
const_cast<IdType*>(args.cumsum_q_seqlens),
const_cast<IdType*>(args.batch_ids),
const_cast<IdType*>(args.tile_ids_per_batch),
const_cast<IdType*>(args.num_blocks_x),
args.sm_scale,
args.bsz,
args.max_block_num,
args.max_block_num_per_seq,
args.q_stride_bsz,
args.q_stride_head_num,
args.kv_stride_block_num,
args.kv_stride_block_size,
args.o_stride_bsz,
args.o_stride_head_num,
args.chunk_size,
args.chunk_num,
args.max_draft_token_num,
tma_load_KV
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
if constexpr (USE_TMA_LOAD_KV) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_KV.get_tma_descriptor());
}
}
template <typename SharedStorage>
CUTLASS_DEVICE void load_q(Params const& mainloop_params,
MainloopPipelineQ pipeline_q,
PipelineStateQ& smem_pipe_write_q,
SharedStorage& shared_storage,
const int thread_idx,
const int bid) {
int start_q_token_idx = mainloop_params.cumsum_q_seqlens[bid];
int offset_Q = mainloop_params.q_stride_bsz * start_q_token_idx;
Tensor mQ = make_tensor(make_gmem_ptr(mainloop_params.Q_ptr + offset_Q), mainloop_params.layout_Q);
Tensor gQ =
local_tile(mQ, select<0, 2>(TileShape_QKD{}), make_coord(_, _0{}))(_, _, _0{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor cQ = cute::make_identity_tensor(gQ.shape());
GmemTiledCopyQ gmem_tiled_copy_q;
auto gmem_thr_copy_q = gmem_tiled_copy_q.get_slice(thread_idx);
Tensor tQgQ = gmem_thr_copy_q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_q.partition_D(sQ);
Tensor tQcQ = gmem_thr_copy_q.partition_D(cQ);
Tensor tQcQGroup = flatten_1(tQcQ);
int valid_q_size = mainloop_params.seq_lens_this_time[bid];
auto q_predicate_fn = [&](auto coords) {
auto s_coords = tQcQGroup(_0{}, coords);
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, valid_q_size);
};
Tensor tQgQiGroup = flatten_1(tQgQ);
Tensor tQsQiGroup = flatten_1(tQsQ);
pipeline_q.producer_acquire(smem_pipe_write_q);
copy_if(gmem_tiled_copy_q, q_predicate_fn, tQgQiGroup, tQsQiGroup);
pipeline_q.producer_commit(smem_pipe_write_q, cutlass::arch::cpasync_barrier_arrive);
++smem_pipe_write_q;
}
template <typename SharedStorage>
CUTLASS_DEVICE void load_kv(Params const& mainloop_params,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_write_kv,
SharedStorage& shared_storage,
const int bid,
const int kv_len,
const int tile_idx) {
int thread_idx = threadIdx.x;
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0);
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor mKV = make_tensor(make_gmem_ptr(mainloop_params.KV_ptr), mainloop_params.layout_KV);
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
GmemTiledCopy gmem_tiled_copy_kv;
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
Tensor tKgK = gmem_thr_copy_kv.partition_S(gKV);
Tensor tKsK = gmem_thr_copy_kv.partition_S(sK);
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
const int block_idx = kv_block_tables(bid, kv_tile_idx);
pipeline_kv.producer_acquire(smem_pipe_write_kv);
Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, block_idx));
Tensor tKsKiGroup =
flatten_1(tKsK(_, _, _, smem_pipe_write_kv.index()));
copy(gmem_tiled_copy_kv, tKgKiGroup, tKsKiGroup);
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
++smem_pipe_write_kv;
}
}
template <typename SharedStorage>
CUTLASS_DEVICE void load_kv_tma(Params const& mainloop_params,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_write_kv,
SharedStorage& shared_storage,
const int bid,
const int kv_len,
const int tile_idx) {
int thread_idx = threadIdx.x;
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor mKV = mainloop_params.tma_load_KV.get_tma_tensor(mainloop_params.layout_KV.shape());
// Prepare the TMA loads
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
auto [tKgK, tKsK] =
tma_partition(mainloop_params.tma_load_KV, _0{}, Layout<_1>{},
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
#pragma unroll 2
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
const int block_idx = kv_block_tables(bid, kv_tile_idx);
pipeline_kv.producer_acquire(smem_pipe_write_kv);
copy(mainloop_params.tma_load_KV.with(*pipeline_kv.producer_get_barrier(smem_pipe_write_kv), /*mcast_mask=*/0),
tKgK(_, block_idx), tKsK(_, smem_pipe_write_kv.index()));
++smem_pipe_write_kv;
}
}
}
};
} // namespace mla_attn
#endif // ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_

View File

@@ -0,0 +1,500 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
#define ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "named_barrier.cuh"
// #define DEBUG_MLA
namespace mla_attn {
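// mma_f16: consumer mainloop used when BLOCK_SHAPE_KV == 64. Warpgroup 1 computes the
// QK GEMM, applies causal/length masking and the online-softmax update, and stages P
// plus the per-row rescale factors in shared memory; both consumer warpgroups then
// accumulate their slice of the PV GEMM from that shared P, synchronizing through the
// kWarpSchedulerWG1 and kWG1WG2Sync named barriers.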
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
MainloopPipelineQ pipeline_q,
PipelineStateQ& smem_pipe_read_q,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_read_kv,
FrgTensorO& tOrO,
AttentionUpdater& attention_updater,
const int thread_idx,
const int bid,
const int kv_len,
const int qo_len,
const int tile_idx,
SharedStorage& shared_storage) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeMD = typename Ktraits::DTypeO;
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
using IdType = typename Ktraits::IdType;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
using SmemLayoutK = typename Ktraits::SmemLayoutK;
using SmemLayoutV = typename Ktraits::SmemLayoutV;
using SmemLayoutP = typename Ktraits::SmemLayoutP;
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _); // (bsz * draft_token_num * num_head)
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
typename Ktraits::TiledMmaQK tiled_mma_qk;
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup);
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
int warp_group_idx = cutlass::canonical_warp_group_idx();
if (warp_group_idx == 1) {
// consumer 0, compute qk
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
constexpr int n_masking_steps = !CAUSAL ? 1 : cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) + 1;
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
bool is_first_step = true;
// wait q
consumer_wait(pipeline_q, smem_pipe_read_q);
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
#pragma unroll 1
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
// wait kv
consumer_wait(pipeline_kv, smem_pipe_read_kv);
// gemm qk
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
tSrS);
// mask
if (masking_step > 0) {
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
if constexpr (!CAUSAL) { // Just masking based on col
if (kv_idx >= kv_len) {
tSrS(i) = AttentionUpdater::fill_value;
}
} else {
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
tSrS(i) = AttentionUpdater::fill_value;
}
}
}
}
// update s (exp(s - m))
Tensor scale_o = is_first_step ? attention_updater.update</*init=*/true>(tSrS) : attention_updater.update</*init=*/false>(tSrS);
is_first_step = false;
Tensor convert_tSrS = convert_type<DTypeKV>(tSrS);
Tensor tPrP = smem_thr_copy_P.retile_S(convert_tSrS);
// gather qk gemm res
cute::copy(smem_tiled_copy_P, tPrP, tPsP);
cute::copy(scale_o, tScalesScale);
// r2s fence wgmma
cutlass::arch::fence_view_async_shared();
// make sure r2s all done
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
attention_updater.rescale_o(tOrO, scale_o);
// pv gemm
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV1(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV2(_, _, _, _0{}), tOrO);
}
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
// sync WG1 WG2
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
}
// release q
pipeline_q.consumer_release(smem_pipe_read_q);
++smem_pipe_read_q;
// normalize
Tensor scale_o = attention_updater.finalize(tSrS); // warp reduce row sum
if (chunk_num_this_seq == 1) {
// norm
cute::copy(scale_o, tScalesScale);
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
attention_updater.rescale_o(tOrO, scale_o);
}
// WG1 write m,d back to gmem
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // each warp covers 16 rows, e.g. t0 -> rows 0 and 8, t4 -> rows 1 and 9
const int warp_idx = thread_idx / 32;
#pragma unroll
for (int w_i = 0; w_i < 2; ++w_i) {
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
if (token_idx < qo_len) {
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
}
}
}
} else if (warp_group_idx == 2) {
// consumer 1, compute pv
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
// wait kv
consumer_wait(pipeline_kv, smem_pipe_read_kv);
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
// A: tPsP
cute::copy(tScalesScale, scale_o);
// rescale
attention_updater.rescale_o(tOrO, scale_o);
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV1(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV2(_, _, _, _0{}), tOrO);
}
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
// sync WG1 WG2
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
}
if (chunk_num_this_seq == 1) {
// norm
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
cute::copy(tScalesScale, scale_o);
attention_updater.rescale_o(tOrO, scale_o);
}
}
return;
}
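// mma_f16_two_stages: variant used when BLOCK_SHAPE_KV == 32, with four V stages in
// shared memory. Warpgroup 1 software-pipelines the loop: it issues the QK GEMM for the
// next KV tile before waiting on the PV GEMM of the current one, double-buffering P and
// the rescale factors (index % 2) so the PV warpgroup can run one step behind.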
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
MainloopPipelineQ pipeline_q,
PipelineStateQ& smem_pipe_read_q,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_read_kv,
FrgTensorO& tOrO,
AttentionUpdater& attention_updater,
const int thread_idx,
const int bid,
const int kv_len,
const int qo_len,
const int tile_idx,
SharedStorage& shared_storage) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeMD = typename Ktraits::DTypeO; // NOTE: m/d are written back in the output dtype (bf16)
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
using IdType = typename Ktraits::IdType;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
using SmemLayoutK = typename Ktraits::SmemLayoutK;
using SmemLayoutV = typename Ktraits::SmemLayoutV;
using SmemLayoutP = typename Ktraits::SmemLayoutP;
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sVt_s3 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 2 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sVt_s4 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 3 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _);
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
typename Ktraits::TiledMmaQK tiled_mma_qk;
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup, _);
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
Tensor tOrV3 = threadMmaPVSS.partition_fragment_B(sVt_s3);
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
int warp_group_idx = cutlass::canonical_warp_group_idx();
if (warp_group_idx == 1) {
// consumer 0, compute qk
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
// wait q
consumer_wait(pipeline_q, smem_pipe_read_q);
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
// wait k
consumer_wait(pipeline_kv, smem_pipe_read_kv);
// first qk gemm
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
tSrS);
// mask
{
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
if constexpr (!CAUSAL) { // Just masking based on col
if (kv_idx >= kv_len) {
tSrS(i) = AttentionUpdater::fill_value;
}
} else {
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
tSrS(i) = AttentionUpdater::fill_value;
}
}
}
}
Tensor scale_o = attention_updater.update</*init=*/true>(tSrS);
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
// gather qk gemm res
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
// r2s fence wgmma
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) : 0;
--kv_tile_idx;
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
PipelineState smem_pipe_read_kv_cur = smem_pipe_read_kv;
++smem_pipe_read_kv;
// wait next kv
consumer_wait(pipeline_kv, smem_pipe_read_kv);
// gemm next qk
gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
tSrS);
attention_updater.rescale_o(tOrO);
// last pv gemm
if (smem_pipe_read_kv_cur.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV1(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv_cur.index() == 1) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV2(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv_cur.index() == 2) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV3(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV4(_, _, _, _0{}), tOrO);
}
// wait cur qk gemm
warpgroup_wait<1>();
// mask p
if (masking_step > 0) {
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
if constexpr (!CAUSAL) { // Just masking based on col
if (kv_idx >= kv_len) {
tSrS(i) = AttentionUpdater::fill_value;
}
} else {
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
tSrS(i) = AttentionUpdater::fill_value;
}
}
}
}
// update s (exp(s - m))
Tensor scale_o = attention_updater.update</*init=*/false>(tSrS);
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
// gather qk gemm res
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
// r2s fence wgmma
cutlass::arch::fence_view_async_shared();
// make sure tSrS r2s done
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
// wait last pv gemm
warpgroup_wait<0>();
// release last kv
pipeline_kv.consumer_release(smem_pipe_read_kv_cur);
}
// release q
pipeline_q.consumer_release(smem_pipe_read_q);
++smem_pipe_read_q;
// compute last pv
attention_updater.rescale_o(tOrO);
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV1(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 1) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV2(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 2) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV3(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV4(_, _, _, _0{}), tOrO);
}
scale_o = attention_updater.finalize(tSrS);
warpgroup_wait<0>();
// release last kv
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
if (chunk_num_this_seq == 1) {
// norm
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
attention_updater.rescale_o(tOrO);
}
// WG1 write m,d back to gmem
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // each warp covers 16 rows, e.g. t0 -> rows 0 and 8, t4 -> rows 1 and 9
const int warp_idx = thread_idx / 32;
#pragma unroll
for (int w_i = 0; w_i < 2; ++w_i) {
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
if (token_idx < qo_len) {
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
}
}
}
} else if (warp_group_idx == 2) {
// consumer 1, compute pv
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
consumer_wait(pipeline_kv, smem_pipe_read_kv);
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
// A: tPsP
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
// rescale
attention_updater.rescale_o(tOrO, scale_o);
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV1(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 1) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV2(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 2) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV3(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV4(_, _, _, _0{}), tOrO);
}
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
}
if (chunk_num_this_seq == 1) {
// norm
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
attention_updater.rescale_o(tOrO, scale_o);
}
}
return;
}
} // namespace mla_attn
#endif // ATTENTION_HOPPER_MAINLOOP_MMA_CUH_

View File

@@ -0,0 +1,575 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_
#define ATTENTION_HOPPER_PREFILL_SM90_CUH_
#include <cuda.h>
#include <cuda_device_runtime_api.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <type_traits>
#include <vector>
#include "attention_updater.cuh"
#include "cute/tensor.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "epilogue.cuh"
#include "helper.h"
#include "kernel_traits.cuh"
#include "mainloop_mma.cuh"
#include "mainloop_load.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA
namespace mla_attn {
using namespace cute;
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
struct Params {
using DTypeQ = DTypeQ_;
using DTypeKV = DTypeKV_;
using DTypeO = DTypeO_;
using IdType = IdType_;
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
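// per-chunk softmax statistics; merged afterwards by merge_multi_chunks_kernel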
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_encoder;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *padding_offsets;
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;
uint32_t q_stride_bsz;
uint32_t q_stride_head_num;
uint32_t kv_stride_block_num;
uint32_t kv_stride_block_size;
uint32_t o_stride_bsz;
uint32_t o_stride_head_num;
int bsz;
int token_num;
int max_seq_len;
int max_block_num;
int max_block_num_per_seq;
int q_num_head;
int qk_head_dim;
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_size;
int chunk_num;
int num_blocks_x_int;
float sm_scale;
};
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 64) { \
constexpr size_t GROUP_SIZE = 64; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
return cudaErrorNotSupported; \
}
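// DISPATCH_GROUP_SIZE turns a runtime group size (the number of query heads per latent
// KV head, params.q_num_head here) into a compile-time GROUP_SIZE usable as a template
// argument. Hypothetical usage sketch (LaunchImpl is a placeholder, not part of this file):
//   DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
//       return LaunchImpl<GROUP_SIZE>(params, stream);)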
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
CUTE_GRID_CONSTANT
typename CollectiveEpilogue::Params const epilogue_params) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeO = typename Ktraits::DTypeO;
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
using TileShape_PDV = typename Ktraits::TileShape_PDV;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS;
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
const int num_blocks_x = mainloop_params.num_blocks_x[0];
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
using PipelineParamsQ = typename MainloopPipelineQ::Params;
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
extern __shared__ char shared_memory[];
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
}
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
if constexpr (use_tma_load_kv) {
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NUM_MMA_THREADS;
} else {
pipeline_params.producer_arv_count = NUM_COPY_THREADS;
pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
}
PipelineParamsQ pipeline_params_q;
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
: MainloopPipelineQ::ThreadCategory::Consumer;
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // only the QK warpgroup consumes Q
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
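// KV pipeline: with TMA loads it is driven by transaction bytes from a single leader
// thread; the cp.async path instead counts producer/consumer arrivals.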
MainloopPipeline pipeline_kv = [&] {
if constexpr (use_tma_load_kv) {
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
/*cluster_shape=*/Shape<_1, _1, _1>{});
} else {
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
}
}();
__syncthreads();
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue;
if (warp_group_idx == 0) {
// producer
if constexpr(USE_REG_EALLOC) {
cutlass::arch::warpgroup_reg_dealloc<72>();
}
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
}
}
}
} else {
const int block_id = blockIdx.x;
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
}
}
}
} else {
// consumer
if constexpr(USE_REG_EALLOC) {
cutlass::arch::warpgroup_reg_alloc<216>();
}
PipelineStateQ smem_pipe_read_q;
PipelineState smem_pipe_read_kv;
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}
collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
} else {
const int block_id = blockIdx.x;
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}
collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
}
}
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
using DTypeKV = typename KernelTraits::DTypeKV;
using DTypeO = typename KernelTraits::DTypeO;
using IdType = typename KernelTraits::IdType;
using NV_TYPE = typename KernelTraits::NV_TYPE;
using CollectiveMainloop =
CollectiveMainloop<KernelTraits, CAUSAL>;
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q
make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
params.Q,
params.KV,
params.m,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_encoder,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.sm_scale,
params.bsz,
params.max_block_num,
params.max_block_num_per_seq,
params.q_stride_bsz,
params.q_stride_head_num,
params.kv_stride_block_num,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_size,
params.chunk_num,
params.max_draft_token_num
});
typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
params.O,
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O
params.O_tmp,
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp
});
// Get the ptr to kernel function.
auto kernel =
MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL,
/*SM_COUNT=*/132, USE_REG_EALLOC, USE_FIXED_BLOCK>;
int smem_size = sizeof(typename KernelTraits::SharedStorage);
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
int device;
cudaGetDevice(&device);
int multiprocessor_count;
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
int act_blocks_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
int gridx;
if constexpr(USE_FIXED_BLOCK) {
gridx = multiprocessor_count;
} else {
gridx = params.num_blocks_x_int;
}
dim3 grid_dims(gridx, 1, 1);
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize, 1, 1);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
mainloop_params, epilogue_params
);
if (params.chunk_num > 1) {
constexpr int vec_size = 16 / sizeof(DTypeO);
constexpr int merge_block_size = 256;
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // cap grid.x at the SM count; token_num (up to ~128k) would be too large a grid dim
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE*>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.seq_lens_encoder,
params.padding_offsets,
reinterpret_cast<NV_TYPE*>(params.O),
params.max_seq_len,
params.chunk_num,
params.q_num_head,
params.chunk_size,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num
);
}
return cudaSuccess;
}
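// Dispatch on the QK head dim (only 576 is supported here) and on the runtime group
// size, then instantiate the kernel traits with a 64x64 Q/KV tile and 2 pipeline stages.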
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
constexpr bool CAUSAL = true;
if constexpr (HEAD_DIM_QK == 576) {
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,
/*BLOCK_SHAPE_Q_=*/64,
/*BLOCK_SHAPE_KV_=*/64,
/*NUM_STAGES_=*/2,
typename Params::DTypeQ,
typename Params::DTypeKV,
typename Params::DTypeO,
typename Params::IdType,
NV_TYPE>,
CAUSAL,
Params,
USE_REG_EALLOC,
USE_FIXED_BLOCK>(params, stream);)
} else {
return cudaErrorNotSupported;
}
return cudaSuccess;
}
} // namespace mla_attn
#endif // ATTENTION_HOPPER_PREFILL_SM90_CUH_

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
#define ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
#include <cuda_runtime.h>
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
namespace mla_attn {
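// Named barrier ids shared by the MLA producer/consumer warpgroups; the mainloop uses
// kWarpSchedulerWG1/2, kWG1WG2Sync, kWG0WG1WG2Sync and kWG1WG2LastSync.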
enum class NamedBarriers {
kQueryEmpty = 0,
kValueEmpty = 1,
kWarpSchedulerWG1 = 2,
kWarpSchedulerWG2 = 3,
kWarpSchedulerWG3 = 4,
kPrefetchIndices = 5,
kOdone = 6,
kWG1WG2Sync = 7,
kWG0WG1WG2Sync = 8,
kWG1WG2LastSync = 9,
};
} // namespace mla_attn
#endif // ATTENTION_HOPPER_NAMED_BARRIERS_CUH_

View File

@@ -0,0 +1,351 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ATTENTION_HOPPER_UTILS_CUH_
#define ATTENTION_HOPPER_UTILS_CUH_
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include <assert.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <stdlib.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cuda_runtime.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cmath>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "cutlass/fast_math.h"
namespace mla_attn {
using namespace cute;
template <typename TensorT>
CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) {
Tensor tensor_flatten = cute::flatten(tensor);
return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten);
}
CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride,
int64_t h_stride) {
return make_layout(make_shape(nnz, head_dim, num_heads),
make_stride(n_stride, cute::_1{}, h_stride));
}
CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) {
return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads)));
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
int head_idx, int offset, int seq_len) {
auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)),
make_coord(offset, _0{}));
auto g_sequence =
make_tensor(g_offset.data(),
make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride()));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
return g_tensor;
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
int head_idx, int offset, int seq_len) {
auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset));
auto g_sequence = make_tensor(g_offset.data(), make_layout(cute::make_shape(seq_len),
cute::make_shape(shape<0>(m_tensor))));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
return g_tensor;
}
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V,
// MMA_N))
template <typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
};
// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16,
// MMA_N))
template <typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{});  // (2, 2, (2, N / 16))
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout),
make_layout(get<2, 1>(l), get<2>(acc_layout)));
};
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
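// Thin wrapper around cute::gemm for WGMMA: `init` zeroes the accumulator on the first
// K-block, and `wg_wait >= 0` waits for that many outstanding warpgroup MMAs before
// returning (a negative value leaves the MMA in flight for the caller to wait on).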
template <bool init = false, int wg_wait = 0, typename TensorA, typename TensorB, typename TensorC,
typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB,
TensorC& tCrC) {
constexpr bool Is_RS =
!cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
}
warpgroup_fence_operand(tCrC);
warpgroup_arrive();
if constexpr (init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
warpgroup_commit_batch();
if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
}
}
#define HOSTDEVICE __host__ __device__
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
HOSTDEVICE inline T& operator[](int i) { return val[i]; }
};
template <typename T, int Size>
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
const AlignedVector<T, Size>* addr_vec =
reinterpret_cast<const AlignedVector<T, Size>*>(addr);
*vec = *addr_vec;
}
template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
AlignedVector<T, Size>* addr_vec =
reinterpret_cast<AlignedVector<T, Size>*>(addr);
*addr_vec = vec;
}
template <size_t vec_size, typename T>
struct prefill_softmax_state_t {
AlignedVector<T, vec_size> o;
float m;
float d;
__device__ __forceinline__ void init() {
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&o) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
}
}
d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = -3.38953e38f;
}
}
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
const float other_m,
const float other_d) {
float m_prev = m, d_prev = d;
m = max(m_prev, other_m);
const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m);
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
d = d_prev * scale1 + other_d * scale2;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = o[i] * scale1_T + other_o[i] * scale2_T;
}
}
__device__ __forceinline__ void normalize() {
const T d_t = static_cast<T>(d);
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] /= d_t;
}
}
};
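// merge_multi_chunks_kernel combines the per-chunk partial outputs written by the MLA
// kernel using their online-softmax statistics (m = running max, d = running sum):
//   m = max(m1, m2)
//   d = d1 * exp(m1 - m) + d2 * exp(m2 - m)
//   o = o1 * exp(m1 - m) + o2 * exp(m2 - m)   (normalized by d at the end)
// Each block reduces the chunks of one (token, head) pair: first per ty row in
// registers, then across the bdy rows through shared memory.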
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
const int * __restrict__ seq_lens_this_time,
const int * __restrict__ seq_lens_decoder,
const int * __restrict__ seq_lens_encoder,
const int * __restrict__ padding_offsets,
T * __restrict__ out, // [token_num, num_heads, head_dim]
const int max_seq_len,
const int num_chunks,
const int num_heads,
const int chunk_size,
const int head_dim,
const int token_num,
const int bsz,
const int max_draft_token_num) {
const int vid = threadIdx.x, ty = threadIdx.y;
const int hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t ori_token_id = qid + padding_offsets[qid];
const uint32_t bid = ori_token_id / max_seq_len;
const int seq_len_q = seq_lens_this_time[bid];
if (seq_len_q == 0) continue;
const uint32_t local_seq_id = ori_token_id % max_seq_len;
int seq_len_kv = seq_lens_decoder[bid];
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
// no merge needed
continue;
}
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&res_vec) + i) = make_half2(0, 0);
}
} else {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
float m;
float d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
m = -3.0e+30f;
}
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset;
offset = ((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid;
float m_prev = m;
float d_prev = d;
const float m_now = multi_m[offset];
const float d_now = multi_d[offset];
m = max(m_prev, m_now);
offset = (((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid) * head_dim + vid * vec_size;
Load<T, vec_size>(&multi_out[offset], &load_vec);
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
d = d * scale1 + d_now * scale2;
#pragma unroll
for (int j = 0; j < vec_size; j++) {
res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
}
}
// store this ty row's partial result to shared memory
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
md_smem[2 * ty] = m;
md_smem[2 * ty + 1] = d;
__syncthreads();
if (ty == 0) {
// merge the bdy partial results
prefill_softmax_state_t<vec_size, T> st;
st.init();
#pragma unroll
for (int i = 0; i < bdy; i++) {
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
Store<T, vec_size>(st.o, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
}
__syncthreads();
}
}
} // namespace mla_attn
#endif // ATTENTION_HOPPER_UTILS_CUH_

View File

@@ -1255,8 +1255,6 @@ __global__ void Marlin(
if constexpr (has_zp && !is_zp_float) { if constexpr (has_zp && !is_zp_float) {
if (is_new_zp) { if (is_new_zp) {
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
FragB frag_zp_0;
FragB frag_zp_1;
int zp_quant_0, zp_quant_1; int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) { if constexpr (w_type.size_bits() == 4) {

View File

@@ -0,0 +1,469 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "append_attn/multi_head_latent_attention_kernel.h"
#include "mla_attn/batch_mla_with_paged_kv_cache.h"
template <paddle::DataType D>
std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& query_bias,
const paddle::optional<paddle::Tensor>& query_out_scales,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const std::string& cache_quant_type_str,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
typedef PDTraits<D> traits_;
typedef typename traits_::data_t data_t;
int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
int max_len_kv_data = max_len_kv.data<int>()[0];
const bool mla_use_tensorcore = get_mla_use_tensorcore();
auto sm_version = GetSMVersion();
if ((speculate_decoder || mla_use_tensorcore) && sm_version < 90) {
PD_THROW("Please use speculate_decoder=0 and FLAGS_mla_use_tensorcore=0 when sm < 90.");
}
auto main_stream = query.stream();
paddle::Tensor fmha_out = paddle::full(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
0,
D,
query.place());
if (max_dec_len_this_time_data > 0) {
if (mla_use_tensorcore) {
BatchMLAWithPagedKVCacheKernel<data_t>(meta_data,
query,
key_cache,
attn_mask,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
cache_quant_type_str,
decoder_num_blocks_data,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
speculate_max_draft_token_num,
causal,
main_stream,
&fmha_out);
} else {
DecodeMLAAttentionKernel<data_t>(
meta_data,
query, // [token_num, num_heads, head_dim]
key_cache,
value_cache,
attn_mask,
out_linear_shifts,
out_linear_smooths,
seq_lens_this_time, // q_seq_len is 1
seq_lens_decoder,
padding_offsets,
cum_offsets,
block_tables,
max_input_length,
max_len_kv_data,
softmax_scale,
out_linear_in_scale,
causal,
main_stream,
&fmha_out);
}
}
return {fmha_out};
}
std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& query_bias,
const paddle::optional<paddle::Tensor>& query_out_scales,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const int nope_size,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
AppendAttnMetaData meta_data;
const auto& query_dims = query.dims();
const auto& key_cache_dims = key_cache.dims();
const int q_hidden_size = query_dims[query_dims.size() - 1];
meta_data.token_nums = query_dims[0];
meta_data.kv_num_heads = key_cache_dims[1];
meta_data.head_dims = key_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = cum_offsets.dims()[0];
switch (query.dtype()) {
case paddle::DataType::BFLOAT16: {
return MultiHeadLatentAttentionKernel<paddle::DataType::BFLOAT16>(
meta_data,
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_cpu,
max_enc_len_this_time,
max_dec_len_this_time,
max_len_kv,
attn_mask,
query_bias,
query_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
cache_quant_type_str,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
speculate_max_draft_token_num,
causal,
speculate_decoder);
}
case paddle::DataType::FLOAT16: {
return MultiHeadLatentAttentionKernel<paddle::DataType::FLOAT16>(
meta_data,
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_cpu,
max_enc_len_this_time,
max_dec_len_this_time,
max_len_kv,
attn_mask,
query_bias,
query_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
cache_quant_type_str,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
speculate_max_draft_token_num,
causal,
speculate_decoder);
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
}
}
}
std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
const std::vector<int64_t>& query_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& padding_offsets_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& encoder_num_blocks_shape,
const std::vector<int64_t>& kv_batch_ids_shape,
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
const std::vector<int64_t>& kv_num_blocks_shape,
const std::vector<int64_t>& decoder_batch_ids_shape,
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
const std::vector<int64_t>& max_enc_len_this_time_shape,
const std::vector<int64_t>& max_dec_len_this_time_shape,
const std::vector<int64_t>& max_len_kv_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
const paddle::optional<std::vector<int64_t>>& query_bias_shape,
const paddle::optional<std::vector<int64_t>>& query_out_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const int nope_size,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const int token_num = query_shape[0];
const int kv_num_heads = key_cache_shape[1];
const int head_dim_qk = key_cache_shape[3];
const int head_dim_v = nope_size;
const int q_hidden_size = query_shape[query_shape.size() - 1];
const int num_heads = q_hidden_size / head_dim_qk;
return {{token_num, num_heads * head_dim_v}};
}
std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
const paddle::DataType& query_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& padding_offsets_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
const paddle::DataType& encoder_num_blocks_dtype,
const paddle::DataType& kv_batch_ids_dtype,
const paddle::DataType& kv_tile_ids_per_batch_dtype,
const paddle::DataType& kv_num_blocks_dtype,
const paddle::DataType& decoder_batch_ids_dtype,
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& decoder_num_blocks_cpu_dtype,
const paddle::DataType& max_enc_len_this_time_dtype,
const paddle::DataType& max_dec_len_this_time_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
const paddle::optional<paddle::DataType>& query_bias_dtype,
const paddle::optional<paddle::DataType>& query_out_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const int nope_size,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
if (compute_dtype == "bf16") {
return {paddle::DataType::BFLOAT16};
} else if (compute_dtype == "fp16") {
return {paddle::DataType::FLOAT16};
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
}
}
PD_BUILD_OP(multi_head_latent_attention)
.Inputs({"query",
"key_cache",
"value_cache",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"cu_seqlens_q",
"padding_offsets",
"cum_offsets",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"decoder_num_blocks_cpu",
"max_enc_len_this_time",
"max_dec_len_this_time",
"max_len_kv",
paddle::Optional("attn_mask"),
paddle::Optional("query_bias"),
paddle::Optional("query_out_scales"),
paddle::Optional("cache_k_quant_scales"),
paddle::Optional("cache_v_quant_scales"),
paddle::Optional("cache_k_dequant_scales"),
paddle::Optional("cache_v_dequant_scales"),
paddle::Optional("cache_k_zp"),
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths")})
.Outputs({"fmha_out"})
.Attrs({"compute_type: std::string",
"cache_quant_type: std::string",
"nope_size: int",
"max_input_length: int",
"softmax_scale: float",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(MultiHeadLatentAttention))
.SetInferShapeFn(PD_INFER_SHAPE(MultiHeadLatentAttentionInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MultiHeadLatentAttentionInferDtype));
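For orientation, the output shape produced above (and by MultiHeadLatentAttentionInferShape) follows directly from the query hidden size and the cache head dim. A small Python sketch with hypothetical sizes:

def mla_output_shape(query_shape, key_cache_shape, nope_size):
    # Mirrors MultiHeadLatentAttentionInferShape: num_heads is derived from the
    # query hidden size and head_dim_qk, and the output keeps only the
    # nope_size (value) part of each head.
    token_num = query_shape[0]
    head_dim_qk = key_cache_shape[3]
    num_heads = query_shape[-1] // head_dim_qk
    return [token_num, num_heads * nope_size]

# Hypothetical sizes: 32 tokens, 16 heads, head_dim_qk=576, nope_size=512.
assert mla_output_shape([32, 16 * 576], [1, 8, 64, 576], 512) == [32, 8192]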

View File

@@ -0,0 +1,73 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <optional>
#include "helper.h"
#include "noauxtc_kernel.h"
std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
paddle::Tensor& scores_with_bias,
int n_group,
int topk_group,
int topk,
float routed_scaling_factor) {
auto input_shape = scores_with_bias.shape();
int64_t num_tokens = input_shape[0];
int64_t num_experts = input_shape[1];
auto input_type = scores_with_bias.dtype();
auto place = scores_with_bias.place();
auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
auto stream = scores_with_bias.stream();
invokeNoAuxTc<float>(reinterpret_cast<float*>(scores.data<float>()),
reinterpret_cast<float*>(group_scores.data<float>()),
reinterpret_cast<float*>(scores_with_bias.data<float>()),
num_tokens,
num_experts,
n_group,
topk_group,
topk,
routed_scaling_factor,
stream);
return {scores};
}
std::vector<paddle::DataType> NoauxTcInferDtype(
const paddle::DataType& scores_dtype,
const paddle::DataType& scores_with_bias_dtype) {
return {scores_dtype};
}
std::vector<std::vector<int64_t>> NoauxTcInferShape(
const std::vector<int64_t>& scores_shape,
const std::vector<int64_t>& gating_output_shape) {
return {scores_shape};
}
PD_BUILD_OP(noaux_tc)
.Inputs({"scores", "scores_with_bias"})
.Outputs({"output_tensor"})
.Attrs({"n_group: int",
"topk_group: int",
"topk:int",
"routed_scaling_factor: float"})
.SetKernelFn(PD_KERNEL(NoauxTc))
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcInferDtype));
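A minimal Python usage sketch for the op registered above, assuming the extension has been built and installed as the fastdeploy_ops package defined in setup_ops.py later in this diff; the import path and sizes are illustrative, not confirmed by the source.

import paddle
from fastdeploy_ops import noaux_tc  # assumed import path after building the custom ops

num_tokens, num_experts, n_group = 4, 64, 8
scores = paddle.rand([num_tokens, num_experts], dtype="float32")
scores_with_bias = scores + paddle.rand([num_tokens, num_experts], dtype="float32")

# Inputs follow the PD_BUILD_OP order: (scores, scores_with_bias) plus the four attrs.
# `scores` is renormalized in place over the selected experts and returned as "output_tensor".
routed = noaux_tc(scores, scores_with_bias, n_group, 4, 8, 1.0)

Only float is instantiated in the kernel header below (INSTANTIATE_NOAUX_TC(float)), so the inputs here are float32.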

View File

@@ -0,0 +1,551 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This code is partially inspired by and references the implementation found
// in NVIDIA TRTLLM.
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
namespace warp_topk {
template <int size, typename T>
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
if (len == 0) {
return 0;
}
return ((len - 1) / size + 1) * size;
}
template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
return (v && !(v & (v - 1)));
}
template <bool greater, typename T>
__device__ bool is_better_than(T val, T baseline) {
return (val > baseline && greater) || (val < baseline && !greater);
}
template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
return max(cache_topk,
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
template <int size, bool ascending, typename T, typename idxT>
struct BitonicMerge {
// input should be a bitonic sequence, and sort it to be a monotonic sequence
__device__ static void merge(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
static_assert(isPowerOf2(size));
static_assert(size >= 2 * WARP_SIZE);
constexpr int arr_len = size / WARP_SIZE;
constexpr int stride = arr_len / 2;
for (int i = 0; i < stride; ++i) {
int const other_i = i + stride;
T& val = val_arr[i];
T& other_val = val_arr[other_i];
if ((val > other_val && ascending) || (val < other_val && !ascending)) {
T tmp = val;
val = other_val;
other_val = tmp;
idxT tmp2 = idx_arr[i];
idx_arr[i] = idx_arr[other_i];
idx_arr[other_i] = tmp2;
}
}
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
idx_arr + arr_len / 2);
}
};
template <int size, bool ascending, typename T, typename idxT>
struct BitonicSort {
__device__ static void sort(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
static_assert(isPowerOf2(size));
static_assert(size >= 2 * WARP_SIZE);
constexpr int arr_len = size / WARP_SIZE;
BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
idx_arr + arr_len / 2);
BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
}
};
template <bool ascending, typename T, typename idxT>
struct BitonicSort<32, ascending, T, idxT> {
__device__ static void sort(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
int const lane = threadIdx.x % WARP_SIZE;
// ascending doesn't matter before merging since all we need is a bitonic
// sequence
for (int stage = 0; stage < 4; ++stage) {
for (int stride = (1 << stage); stride > 0; stride /= 2) {
bool reverse = (lane >> stage) & 2;
bool is_second = lane & stride;
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
*val_arr = other;
*idx_arr = other_idx;
}
}
}
BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
}
};
template <bool ascending, typename T, typename idxT>
struct BitonicMerge<32, ascending, T, idxT> {
__device__ static void merge(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
int const lane = threadIdx.x % WARP_SIZE;
for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
bool is_second = lane & stride;
T& val = *val_arr;
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
idxT& idx = *idx_arr;
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
if (val != other && ((val > other) == (ascending != is_second))) {
val = other;
idx = other_idx;
}
}
}
};
template <int capacity, bool greater, typename T, typename idxT>
class WarpSort {
public:
__device__ WarpSort(idxT k, T dummy)
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
for (int i = 0; i < max_arr_len_; ++i) {
val_arr_[i] = dummy_;
}
}
// load and merge k sorted values
__device__ void load_sorted(T const* __restrict__ in,
idxT const* __restrict__ in_idx,
idxT start) {
idxT idx = start + WARP_SIZE - 1 - lane_;
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
if (idx < start + k_) {
T t = in[idx];
if (is_better_than<greater>(t, val_arr_[i])) {
val_arr_[i] = t;
idx_arr_[i] = in_idx[idx];
}
}
}
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
}
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
for (int i = 0; i < max_arr_len_; ++i) {
idxT out_i = i * WARP_SIZE + lane_;
if (out_i < k_) {
out[out_i] = val_arr_[i];
out_idx[out_i] = idx_arr_[i];
}
}
}
__device__ void dumpIdx(idxT* __restrict__ out_idx) const {
for (int i = 0; i < max_arr_len_; ++i) {
idxT out_i = i * WARP_SIZE + lane_;
if (out_i < k_) {
out_idx[out_i] = idx_arr_[i];
}
}
}
protected:
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
T val_arr_[max_arr_len_];
idxT idx_arr_[max_arr_len_];
int const lane_;
idxT const k_;
T const dummy_;
}; // end class WarpSort
template <int capacity, bool greater, typename T, typename idxT>
class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
public:
__device__ WarpSelect(idxT k, T dummy)
: WarpSort<capacity, greater, T, idxT>(k, dummy),
k_th_(dummy),
k_th_lane_((k - 1) % WARP_SIZE) {
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
int const num_of_warp = blockDim.x / WARP_SIZE;
int const warp_id = threadIdx.x / WARP_SIZE;
val_smem_ = reinterpret_cast<T*>(smem_buf);
val_smem_ += warp_id * WARP_SIZE;
idx_smem_ = reinterpret_cast<idxT*>(
smem_buf +
round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
idx_smem_ += warp_id * WARP_SIZE;
}
__device__ void add(T const* in, idxT start, idxT end) {
idxT const end_for_fullwarp =
round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
T val = (i < end) ? in[i] : dummy_;
add(val, i);
}
}
__device__ void add(T val, idxT idx) {
bool do_add = is_better_than<greater>(val, k_th_);
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
if (mask == 0) {
return;
}
int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
if (do_add && pos < WARP_SIZE) {
val_smem_[pos] = val;
idx_smem_[pos] = idx;
do_add = false;
}
smem_buf_len_ += __popc(mask);
if (smem_buf_len_ >= WARP_SIZE) {
__syncwarp();
merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
smem_buf_len_ -= WARP_SIZE;
}
if (do_add) {
pos -= WARP_SIZE;
val_smem_[pos] = val;
idx_smem_[pos] = idx;
}
__syncwarp();
}
__device__ void done() {
if (smem_buf_len_) {
T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
merge_buf_(val, idx);
}
// after done(), smem is used for merging results among warps
__syncthreads();
}
private:
__device__ void set_k_th_() {
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
}
__device__ void merge_buf_(T val, idxT idx) {
BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);
T& old = val_arr_[max_arr_len_ - 1];
if (is_better_than<greater>(val, old)) {
old = val;
idx_arr_[max_arr_len_ - 1] = idx;
}
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
set_k_th_();
}
using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
using WarpSort<capacity, greater, T, idxT>::val_arr_;
using WarpSort<capacity, greater, T, idxT>::idx_arr_;
using WarpSort<capacity, greater, T, idxT>::lane_;
using WarpSort<capacity, greater, T, idxT>::k_;
using WarpSort<capacity, greater, T, idxT>::dummy_;
T* val_smem_;
idxT* idx_smem_;
int smem_buf_len_ = 0;
T k_th_;
int const k_th_lane_;
}; // end class WarpSelect
} // namespace warp_topk
template <typename T>
__device__ void topk_with_k2(T* output,
T const* input,
cg::thread_block_tile<32> const& tile,
int32_t const lane_id,
int const num_experts_per_group) {
// Get the top2 per thread
T largest = cuda::std::numeric_limits<T>::min();
T second_largest = cuda::std::numeric_limits<T>::min();
if (num_experts_per_group > WARP_SIZE) {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = input[i];
if (value > largest) {
second_largest = largest;
largest = value;
} else if (value > second_largest) {
second_largest = value;
}
}
} else {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
largest = input[i];
}
}
__syncwarp(); // Ensure all threads have valid data before reduction
// Get the top2 warpwise
T max1 = cg::reduce(tile, largest, cg::greater<T>());
T max2 = max1;
bool equal_to_max1 = (max1 == largest);
int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));
if (count_max1 == 1) {
largest = (largest == max1) ? second_largest : largest;
max2 = cg::reduce(tile, largest, cg::greater<T>());
}
if (lane_id == 0) {
*output = max1 + max2;
}
}
template <typename T>
__global__ void topk_with_k2_kernel(T* output,
T* input,
int64_t const num_tokens,
int64_t const num_cases,
int64_t const n_group,
int64_t const num_experts_per_group) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
if (case_id < num_cases) {
input += case_id * num_experts_per_group;
output += case_id;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
}
}
template <typename T>
__global__ void group_idx_and_topk_idx_kernel(
T* scores,
T const* group_scores,
T* scores_with_bias,
int64_t const num_tokens,
int64_t const n_group,
int64_t const topk_group,
int64_t const topk,
int64_t const num_experts,
int64_t const num_experts_per_group,
double routed_scaling_factor) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id =
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
scores_with_bias += case_id * num_experts;
scores += case_id * num_experts;
group_scores += case_id * n_group;
int32_t align_num_experts_per_group =
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
// store the target topk idx
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
T* s_topk_value =
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
warp_id * topk;
T value = cuda::std::numeric_limits<T>::min();
T topk_group_value = cuda::std::numeric_limits<T>::min();
int32_t num_equalto_topkth_group;
if ((n_group > topk_group) && (case_id < num_tokens)) {
// calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
if (lane_id < n_group) {
value = group_scores[lane_id];
}
int count_equal_to_top_value = WARP_SIZE - n_group;
int pre_count_equal_to_top_value = 0;
// Loop to find the score of the topk_group-th largest group (the selection threshold)
while (count_equal_to_top_value < target_num_min) {
__syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) {
value = cuda::std::numeric_limits<T>::min();
}
pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value = __popc(__ballot_sync(
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
}
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
}
__syncthreads();
warp_topk::WarpSelect</*capacity*/ WARP_SIZE, /*greater*/ true, T, int32_t>
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
int count_equalto_topkth_group = 0;
if (case_id < num_tokens) {
for (int i_group = 0; i_group < n_group; i_group++) {
if ((group_scores[i_group] > topk_group_value) ||
((group_scores[i_group] == topk_group_value) &&
(count_equalto_topkth_group < num_equalto_topkth_group))) {
int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) {
T candidates = i < num_experts_per_group
? scores_with_bias[offset + i]
: cuda::std::numeric_limits<T>::min();
queue.add(candidates, offset + i);
}
if (group_scores[i_group] == topk_group_value) {
count_equalto_topkth_group++;
}
}
}
queue.done();
__syncwarp();
// Get the topk_idx
queue.dumpIdx(s_topk_idx);
__syncwarp();
}
// Load the valid score value
// Calculate the summation
float topk_sum = 1e-20;
if (case_id < num_tokens) {
for (int i = lane_id;
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
i += WARP_SIZE) {
T value = i < topk ? scores[s_topk_idx[i]]
: 0.0f; // Load the valid value of expert
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += reduce(tile, value, cg::plus<float>());
}
}
__syncthreads();
if (case_id < num_tokens) {
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
scores[i] = 0;
}
}
__threadfence();
__syncthreads();
if (case_id < num_tokens) {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
scores[s_topk_idx[i]] = value;
}
}
}
template <typename T>
void invokeNoAuxTc(T* scores,
T* group_scores,
T* scores_with_bias,
int64_t const num_tokens,
int64_t const num_experts,
int64_t const n_group,
int64_t const topk_group,
int64_t const topk,
double const routed_scaling_factor,
cudaStream_t const stream) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores,
scores_with_bias,
num_tokens,
num_cases,
n_group,
num_experts / n_group);
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
size_t dynamic_smem_in_bytes =
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);
group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
BLOCK_SIZE,
dynamic_smem_in_bytes,
stream>>>(scores,
group_scores,
scores_with_bias,
num_tokens,
n_group,
topk_group,
topk,
num_experts,
num_experts / n_group,
routed_scaling_factor);
}
#define INSTANTIATE_NOAUX_TC(T) \
template void invokeNoAuxTc<T>(T * scores, \
T * group_scores, \
T * scores_with_bias, \
int64_t const num_tokens, \
int64_t const num_experts, \
int64_t const n_group, \
int64_t const topk_group, \
int64_t const topk, \
double const routed_scaling_factor, \
cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float);
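For sanity checking, a rough NumPy reference of what the two kernels above compute per token: group scores as the sum of each group's two largest biased values, group-limited top-k expert selection, then renormalization of the original scores. It is illustrative only and ignores the kernels' tie-breaking details.

import numpy as np

def noaux_tc_reference(scores, scores_with_bias, n_group, topk_group, topk,
                       routed_scaling_factor):
    # scores and scores_with_bias: float arrays of shape [num_tokens, num_experts].
    num_tokens, num_experts = scores_with_bias.shape
    experts_per_group = num_experts // n_group
    out = np.zeros_like(scores)
    for t in range(num_tokens):
        grouped = scores_with_bias[t].reshape(n_group, experts_per_group)
        # topk_with_k2_kernel: each group is scored by its two largest biased values.
        group_scores = np.sort(grouped, axis=1)[:, -2:].sum(axis=1)
        keep = np.argsort(-group_scores)[:topk_group]
        # group_idx_and_topk_idx_kernel: top-k experts restricted to the kept groups.
        masked = np.full(num_experts, -np.inf)
        for g in keep:
            start = g * experts_per_group
            masked[start:start + experts_per_group] = scores_with_bias[t, start:start + experts_per_group]
        topk_idx = np.argsort(-masked)[:topk]
        topk_sum = scores[t, topk_idx].sum() + 1e-20
        out[t, topk_idx] = scores[t, topk_idx] / topk_sum * routed_scaling_factor
    return out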

View File

@@ -50,11 +50,13 @@ __global__ void quant_per_token_per_block(const T *input,
max_value_thread = max(abs(load_vec_float[vid]), max_value_thread); max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
} }
// get max value per warp // get max value per warp
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread); max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread); max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread); max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread); max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread); max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread);
// broadcast max_value
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
max_value_thread = max(max_value_thread, epsilon); max_value_thread = max(max_value_thread, epsilon);
float scale_to_store = max_value_thread / MAX_VALUE; float scale_to_store = max_value_thread / MAX_VALUE;
// quant // quant
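The hunk above replaces the __shfl_xor_sync butterfly reduction (which leaves the maximum in every lane) with __shfl_down_sync, which leaves it only in lane 0, hence the explicit __shfl_sync broadcast that is added. The scale math itself is unchanged and amounts to the sketch below; MAX_VALUE and epsilon stand in for the kernel's constants, and the final division mirrors the quantization step that follows the hunk.

import numpy as np

def per_block_quant(block, max_value=448.0, epsilon=1e-10):
    # Scale the block by its absolute maximum, clamped by epsilon to avoid
    # dividing by zero; max_value/epsilon are placeholders for the kernel's constants.
    block = np.asarray(block, dtype=np.float32)
    max_abs = max(np.abs(block).max(), epsilon)
    scale_to_store = max_abs / max_value
    return block / scale_to_store, scale_to_store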

View File

@@ -267,6 +267,9 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/text_image_index_out.cu", "gpu_ops/text_image_index_out.cu",
"gpu_ops/text_image_gather_scatter.cu", "gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu", "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
] ]
# pd_disaggregation # pd_disaggregation
@@ -376,6 +379,8 @@ elif paddle.is_compiled_with_cuda():
# append_attention # append_attention
sources += ["gpu_ops/append_attention.cu"] sources += ["gpu_ops/append_attention.cu"]
sources += find_end_files("gpu_ops/append_attn", ".cu") sources += find_end_files("gpu_ops/append_attn", ".cu")
# mla
sources += ["gpu_ops/multi_head_latent_attention.cu"]
# gemm_dequant # gemm_dequant
sources += ["gpu_ops/int8_gemm_with_cutlass/gemm_dequant.cu"] sources += ["gpu_ops/int8_gemm_with_cutlass/gemm_dequant.cu"]
# speculate_decoding # speculate_decoding
@@ -441,6 +446,10 @@ elif paddle.is_compiled_with_cuda():
sources += find_end_files(fp8_auto_gen_directory, ".cu") sources += find_end_files(fp8_auto_gen_directory, ".cu")
if cc >= 90 and nvcc_version >= 12.0:
# Hopper optmized mla
sources += find_end_files("gpu_ops/mla_attn", ".cu")
setup( setup(
name="fastdeploy_ops", name="fastdeploy_ops",
ext_modules=CUDAExtension( ext_modules=CUDAExtension(

View File

@@ -15,6 +15,8 @@
""" """
import os import os
import subprocess
import sys
# suppress warning log from paddlepaddle # suppress warning log from paddlepaddle
os.environ["GLOG_minloglevel"] = "2" os.environ["GLOG_minloglevel"] = "2"
@@ -30,3 +32,48 @@ try:
use_triton_in_paddle.make_triton_compatible_with_paddle() use_triton_in_paddle.make_triton_compatible_with_paddle()
except ImportError: except ImportError:
pass pass
# TODO(tangbinhan): remove this code
def _patch_fastsafetensors():
try:
file_path = subprocess.check_output([
sys.executable, "-c", "import fastsafetensors, os; \
print(os.path.join(os.path.dirname(fastsafetensors.__file__), \
'frameworks', '_paddle.py'))"
]).decode().strip()
with open(file_path, 'r') as f:
content = f.read()
if "DType.U16: DType.BF16," in content and "DType.U8: paddle.uint8," in content:
return
modified = False
if "DType.U16: DType.BF16," not in content:
lines = content.splitlines()
new_lines = []
inside_block = False
for line in lines:
new_lines.append(line)
if 'need_workaround_dtypes: Dict[DType, DType] = {' in line:
inside_block = True
elif inside_block and '}' in line:
new_lines.insert(-1, ' DType.U16: DType.BF16,')
inside_block = False
modified = True
content = "\n".join(new_lines)
if "DType.I8: paddle.uint8," in content:
content = content.replace("DType.I8: paddle.uint8,",
"DType.U8: paddle.uint8,")
modified = True
if modified:
with open(file_path, 'w') as f:
f.write(content + "\n")
except Exception as e:
print(f"Failed to patch fastsafetensors: {e}")
_patch_fastsafetensors()

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional, Literal
from paddleformers.transformers.configuration_utils import PretrainedConfig from paddleformers.transformers.configuration_utils import PretrainedConfig
@@ -51,7 +51,6 @@ class ModelConfig(PretrainedConfig):
top_p = 0.0 top_p = 0.0
temperature = 1.0 temperature = 1.0
rope_theta = 10000.0 rope_theta = 10000.0
rope_scaling = None
penalty_score = 1.0 penalty_score = 1.0
frequency_score = 0.0 frequency_score = 0.0
presence_score = 0.0 presence_score = 0.0
@@ -142,6 +141,7 @@ class MoEConfig:
moe_num_shared_experts = (0, ) moe_num_shared_experts = (0, )
moe_layer_start_index = 0 moe_layer_start_index = 0
moe_layer_end_index = None moe_layer_end_index = None
moe_use_aux_free: bool = False
num_max_dispatch_tokens_per_rank = 256 num_max_dispatch_tokens_per_rank = 256
im_patch_id = ( im_patch_id = (
100295 # multimodality, TODO(liuyuanle): read from config.json 100295 # multimodality, TODO(liuyuanle): read from config.json
@@ -163,7 +163,6 @@ class ParallelConfig:
# The embedding weight distributed on your gpu cards is divided by row or column. # The embedding weight distributed on your gpu cards is divided by row or column.
# Defaults to False means divide by row. When vocab_size can not be divided by world_size # Defaults to False means divide by row. When vocab_size can not be divided by world_size
# but hidden_size can, we can consider split embedding weight by column. # but hidden_size can, we can consider split embedding weight by column.
column_cut = False # (bool, optional)
""" """
From old version worker args From old version worker args
TODO(gongshaotian): Reclassify TODO(gongshaotian): Reclassify
@@ -194,18 +193,13 @@ class ParallelConfig:
engine_pid: Optional[int] = None engine_pid: Optional[int] = None
# Do profile or not # Do profile or not
do_profile: bool = False do_profile: bool = False
# Dynamic load weight or not
dynamic_load_weight: bool = False
# #
pad_token_id: int = -1 pad_token_id: int = -1
# #
eos_tokens_lens: int = 2 eos_tokens_lens: int = 2
# Enable chunked prefill # Enable chunked prefill
enable_chunked_prefill: str = "store_true" enable_chunked_prefill: str = "store_true"
""" #
- APPEND_ATTN:
"""
attention_backend: str = "APPEND_ATTN"
max_num_batched_tokens: int = 2048 max_num_batched_tokens: int = 2048
# enable prefix cache # enable prefix cache
enable_prefix_caching = None enable_prefix_caching = None
@@ -354,9 +348,27 @@ class GraphOptimizationConfig:
@dataclass @dataclass
class LoadConfig: class LoadConfig:
""" """
Configuration for loading parameter Configuration for dynamic weight loading strategies
Attributes:
dynamic_load_weight: Whether to enable dynamic weight loading
load_strategy: Specifies the weight loading method when enabled:
- 'ipc': Real-time IPC streaming with automatic resharding
- 'ipc_no_reshard': Real-time IPC streaming without resharding
- 'ipc_snapshot': Load from disk snapshot of IPC weights
- 'meta': used by RL training workers; no weights are loaded
- None: No dynamic loading
""" """
pass use_fastsafetensor: bool = False
dynamic_load_weight: bool = False
load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None
def __post_init__(self):
if self.load_strategy is not None and not self.dynamic_load_weight:
raise ValueError("Load strategy requires dynamic_load_weight=True")
if self.dynamic_load_weight and self.load_strategy is None:
raise ValueError("Must specify load_strategy when dynamic_load_weight is True")
@dataclass @dataclass
@@ -392,7 +404,7 @@ class FDConfig:
init=True) # type: ignore init=True) # type: ignore
device_config: DeviceConfig = field(default=None, device_config: DeviceConfig = field(default=None,
init=True) # type: ignore init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True)
quant_config: Optional[QuantConfigBase] = None quant_config: Optional[QuantConfigBase] = None
graph_opt_config: Optional[GraphOptimizationConfig] = None graph_opt_config: Optional[GraphOptimizationConfig] = None
moe_config: MoEConfig = field(default=None, init=True) # type: ignore moe_config: MoEConfig = field(default=None, init=True) # type: ignore
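A short sketch of the new LoadConfig contract; the import path is assumed and the strategy value is illustrative.

from fastdeploy.config import LoadConfig  # assumed import path

cfg = LoadConfig(dynamic_load_weight=True, load_strategy="ipc_snapshot")  # valid combination

try:
    LoadConfig(load_strategy="ipc")  # strategy given without dynamic_load_weight=True
except ValueError as err:
    print(err)  # "Load strategy requires dynamic_load_weight=True"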

View File

@@ -16,48 +16,54 @@
import time import time
import os import os
import subprocess import multiprocessing
import signal
from fastdeploy.entrypoints.llm import LLM from fastdeploy.entrypoints.llm import LLM
from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.engine.sampling_params import SamplingParams
model_name_or_path = "./models/eb45t02/"
model_name_or_path = "baidu/ERNIE-4.5-21B-A3B-Paddle"
prefill_cmd = (f"FD_LOG_DIR=log_prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python fastdeploy.entrypoints.openai.api_server.py" def start_decode(model_name_or_path):
+ f" --model {model_name_or_path} --port 9811" os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+ f" --splitwise-role prefill --tensor-parallel-size 4"
+ f" --engine-worker-queue-port 6676 --cache-queue-port 55663")
prefill_instance = subprocess.Popen(
prefill_cmd,
stdout=subprocess.PIPE,
shell=True,
preexec_fn=os.setsid,
)
# # Hyperparameter settings
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
os.environ["FD_LOG_DIR"] = "log_decode" os.environ["FD_LOG_DIR"] = "log_decode"
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
llm_decode = LLM( llm_decode = LLM(
model=model_name_or_path, model=model_name_or_path,
tensor_parallel_size=4, tensor_parallel_size=1,
splitwise_role="decode", splitwise_role="decode",
engine_worker_queue_port=6678, engine_worker_queue_port=6678,
innode_prefill_ports=[6676], innode_prefill_ports=[6676],
cache_queue_port=55668 cache_queue_port=55668
) )
return llm_decode
def start_prefill(model_name_or_path):
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["FD_LOG_DIR"] = "log_prefill"
llm_prefill = LLM(
model=model_name_or_path,
tensor_parallel_size=1,
splitwise_role="prefill",
engine_worker_queue_port=6677,
cache_queue_port=55667,
)
def main():
prefill = multiprocessing.Process(
target=start_prefill,
args=(model_name_or_path,))
prefill.start()
time.sleep(10)
llm_decode = start_decode(model_name_or_path)
output = llm_decode.generate(prompts=["who are you", "what can you do"], use_tqdm=True) output = llm_decode.generate(prompts=["who are you", "what can you do"], use_tqdm=True)
print(output) print(output)
prefill.join()
os.killpg(prefill_instance.pid, signal.SIGTERM)
if __name__ == "__main__":
main()

View File

@@ -87,10 +87,14 @@ class EngineArgs:
""" """
Configuration for speculative execution. Configuration for speculative execution.
""" """
dynamic_load_weight: int = 0 dynamic_load_weight: bool = False
""" """
dynamic load weight dynamic load weight
""" """
load_strategy: str = "meta"
"""
dynamic load weight strategy
"""
quantization: str = None quantization: str = None
guided_decoding_backend: str = "off" guided_decoding_backend: str = "off"
""" """
@@ -364,13 +368,16 @@ class EngineArgs:
type=json.loads, type=json.loads,
default=EngineArgs.speculative_config, default=EngineArgs.speculative_config,
help="Configuration for speculative execution.") help="Configuration for speculative execution.")
model_group.add_argument( model_group.add_argument(
"--dynamic-load-weight", "--dynamic-load-weight",
type=int, action='store_true',
default=EngineArgs.dynamic_load_weight, default=EngineArgs.dynamic_load_weight,
help="Flag to indicate whether to load weight dynamically.") help="Flag to indicate whether to load weight dynamically.")
model_group.add_argument(
"--load-strategy",
type=str,
default=EngineArgs.load_strategy,
help="Flag to dynamic load strategy.")
model_group.add_argument("--engine-worker-queue-port", model_group.add_argument("--engine-worker-queue-port",
type=int, type=int,
default=EngineArgs.engine_worker_queue_port, default=EngineArgs.engine_worker_queue_port,
@@ -383,6 +390,7 @@ class EngineArgs:
"default is None. The priority of this configuration "\ "default is None. The priority of this configuration "\
"is lower than that of the config file. " \ "is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.") "More complex quantization methods need to be configured via the config file.")
model_group.add_argument( model_group.add_argument(
"--enable-static-graph-inference", "--enable-static-graph-inference",
action='store_true', action='store_true',
@@ -668,8 +676,9 @@ class EngineArgs:
""" """
return ModelConfig(model_name_or_path=self.model, return ModelConfig(model_name_or_path=self.model,
config_json_file=self.model_config_name, config_json_file=self.model_config_name,
quantization=self.quantization,
dynamic_load_weight=self.dynamic_load_weight, dynamic_load_weight=self.dynamic_load_weight,
quantization=self.quantization) load_strategy=self.load_strategy)
def create_cache_config(self, model_cfg) -> CacheConfig: def create_cache_config(self, model_cfg) -> CacheConfig:
""" """
@@ -749,6 +758,9 @@ class EngineArgs:
speculative_cfg = self.create_speculative_config() speculative_cfg = self.create_speculative_config()
assert not (self.use_cudagraph and self.enable_prefix_caching), \
"Prefix caching cannot be used with CUDA graph"
return Config( return Config(
model_name_or_path=self.model, model_name_or_path=self.model,
model_config=model_cfg, model_config=model_cfg,

View File

@@ -41,7 +41,8 @@ class ModelConfig:
def __init__(self, def __init__(self,
model_name_or_path: str, model_name_or_path: str,
config_json_file: str = "config.json", config_json_file: str = "config.json",
dynamic_load_weight: int = 0, dynamic_load_weight: bool = False,
load_strategy: str="meta",
quantization: str = None, quantization: str = None,
download_dir: Optional[str] = None): download_dir: Optional[str] = None):
""" """
@@ -55,6 +56,7 @@ class ModelConfig:
self.model_dir = model_name_or_path self.model_dir = model_name_or_path
self.is_unified_ckpt = check_unified_ckpt(self.model_dir) self.is_unified_ckpt = check_unified_ckpt(self.model_dir)
self.dynamic_load_weight = dynamic_load_weight self.dynamic_load_weight = dynamic_load_weight
self.load_strategy = load_strategy
self.quantization = quantization self.quantization = quantization
config_file = os.path.join(model_name_or_path, config_json_file) config_file = os.path.join(model_name_or_path, config_json_file)
@@ -584,13 +586,11 @@ class Config:
self.guided_decoding_backend = guided_decoding_backend self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace self.disable_any_whitespace = disable_any_whitespace
if self.innode_prefill_ports is not None: if self.innode_prefill_ports is not None:
if not isinstance(self.innode_prefill_ports, list): if not isinstance(self.innode_prefill_ports, list):
ports = str(self.innode_prefill_ports).split(',') ports = str(self.innode_prefill_ports).split(',')
self.innode_prefill_ports = [int(port) for port in ports] self.innode_prefill_ports = [int(port) for port in ports]
assert self.splitwise_role in ["mixed", "prefill", "decode"] assert self.splitwise_role in ["mixed", "prefill", "decode"]
# TODO # TODO
@@ -728,7 +728,7 @@ class Config:
), "XPU currently do not support guided_decoding" ), "XPU currently do not support guided_decoding"
try: try:
pass import xgrammar
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}" f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"

View File

@@ -286,6 +286,8 @@ class LLMEngine(object):
while self.running: while self.running:
try: try:
results = self.scheduler.get_results() results = self.scheduler.get_results()
if len(results) == 0:
time.sleep(0.001)
for request_id, contents in results.items(): for request_id, contents in results.items():
for result in contents: for result in contents:
self.zmq_server.send_multipart(request_id, result) self.zmq_server.send_multipart(request_id, result)
@@ -444,8 +446,8 @@ class LLMEngine(object):
enable_thinking = None enable_thinking = None
if kwargs is not None: if kwargs is not None:
enable_thinking = kwargs.get("enable_thinking", None) enable_thinking = kwargs.get("enable_thinking", None)
request = self.data_processor.process_request(request, request = self.data_processor.process_request(
self.cfg.max_model_len, enable_thinking=enable_thinking) request, self.cfg.max_model_len, enable_thinking=enable_thinking)
request.prompt_token_ids_len = len(request.prompt_token_ids) request.prompt_token_ids_len = len(request.prompt_token_ids)
input_ids_len = request.prompt_token_ids_len input_ids_len = request.prompt_token_ids_len
request.set( request.set(
@@ -453,7 +455,8 @@ class LLMEngine(object):
min(self.cfg.max_model_len - input_ids_len, min(self.cfg.max_model_len - input_ids_len,
request.get("max_tokens"))) request.get("max_tokens")))
if request.get("reasoning_max_tokens") is None: if request.get("reasoning_max_tokens") is None:
default_reasoning_max_tokens = max(int(request.get("max_tokens") * 0.8), 1) default_reasoning_max_tokens = max(
int(request.get("max_tokens") * 0.8), 1)
request.set("reasoning_max_tokens", default_reasoning_max_tokens) request.set("reasoning_max_tokens", default_reasoning_max_tokens)
min_tokens = request.get("min_tokens") min_tokens = request.get("min_tokens")
if input_ids_len + min_tokens >= self.cfg.max_model_len: if input_ids_len + min_tokens >= self.cfg.max_model_len:
@@ -963,8 +966,8 @@ class LLMEngine(object):
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python", "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
"FLAGS_use_append_attn": 1, "FLAGS_use_append_attn": 1,
"NCCL_ALGO": "Ring", "NCCL_ALGO": "Ring",
"FLAGS_hardamard_moe_block_size": 128,
"FLAGS_max_partition_size": 32768, "FLAGS_max_partition_size": 32768,
"FLAGS_hardamard_moe_block_size": 128,
} }
# environment variables needed by Dy2St # environment variables needed by Dy2St
variables.update({ variables.update({
@@ -1017,6 +1020,12 @@ class LLMEngine(object):
worker_path = "../worker/vl_worker_process.py" worker_path = "../worker/vl_worker_process.py"
py_script = os.path.join(current_dir_path, worker_path) py_script = os.path.join(current_dir_path, worker_path)
ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, 'sp_model')
else len(self.data_processor.tokenizer.vocab)
)
arguments = ( arguments = (
f" --nnodes {str(self.cfg.nnode)}" f" --nnodes {str(self.cfg.nnode)}"
f" --devices {self.cfg.device_ids} {py_script}" f" --devices {self.cfg.device_ids} {py_script}"
@@ -1037,13 +1046,14 @@ class LLMEngine(object):
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}" f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}" f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
f" --quantization {self.cfg.model_config.quantization}" f" --quantization {self.cfg.model_config.quantization}"
f" --ori_vocab_size {len(self.data_processor.tokenizer)}" f" --ori_vocab_size {ori_vocab_size}"
f" --speculative_method {self.cfg.speculative_config.method}" f" --speculative_method {self.cfg.speculative_config.method}"
f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}" f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}"
f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}" f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
f" --speculative_model_quantization {self.cfg.speculative_config.quantization}" f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
f" --max_capture_batch_size {self.cfg.max_capture_batch_size}" f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}") f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}")
worker_append_flag = { worker_append_flag = {
"enable_expert_parallel": "enable_expert_parallel":
@@ -1188,7 +1198,8 @@ class LLMEngine(object):
line = line.decode('utf-8', errors='ignore') line = line.decode('utf-8', errors='ignore')
if self.worker_init_status.get("finished", False): if self.worker_init_status.get("finished", False):
break break
if match := re.search(r'Loading checkpoint shards:\s*(\d+)', if match := re.search(
r'Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)',
line): line):
self.worker_init_status["weight_loadding"] = eval( self.worker_init_status["weight_loadding"] = eval(
match.group(1)) * 1.0 / 100 match.group(1)) * 1.0 / 100
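The widened regex accepts the new fastsafetensors/safetensors progress lines as well as the old format; a quick check with illustrative log lines:

import re

pattern = r'Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)'
for line in ("Loading checkpoint shards: 40",
             "Loading safetensors checkpoint shards: 75",
             "Loading fastsafetensors checkpoint shards: 100"):
    print(re.search(pattern, line).group(1))  # 40, 75, 100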

View File

@@ -221,6 +221,9 @@ class OpenAIServingChat:
else: else:
choice.finish_reason = "length" choice.finish_reason = "length"
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
choice.finish_reason = "length"
if request.metadata is not None and request.metadata.get("training", False) and delta_text != "": if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
choice.delta.token_ids = output["token_ids"] choice.delta.token_ids = output["token_ids"]
if include_continuous_usage: if include_continuous_usage:
@@ -335,6 +338,9 @@ class OpenAIServingChat:
choice.finish_reason = "tool_calls" choice.finish_reason = "tool_calls"
else: else:
choice.finish_reason = "length" choice.finish_reason = "length"
if final_res.get("error_msg") is not None and "Recover" in final_res["error_msg"]:
choice.finish_reason = "length"
choices.append(choice) choices.append(choice)
num_prompt_tokens = len(prompt_token_ids) num_prompt_tokens = len(prompt_token_ids)

View File

@@ -82,13 +82,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_MOE_BACKEND": "FD_MOE_BACKEND":
lambda: os.getenv("FD_MOE_BACKEND", "cutlass"), lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
# Whether to disable recomputing (recovering) a request when the KV cache is full.
"FD_DISABLED_RECOVER":
lambda: os.getenv("FD_DISABLED_RECOVER", "0"),
# Set triton kernel JIT compilation directory. # Set triton kernel JIT compilation directory.
"FD_TRITON_KERNEL_CACHE_DIR": "FD_TRITON_KERNEL_CACHE_DIR":
lambda: os.getenv("FD_TRITON_KERNEL_CACHE_DIR", None), lambda: os.getenv("FD_TRITON_KERNEL_CACHE_DIR", None),
# Whether to transition from standalone PD decoupling to centralized inference # Whether to transition from standalone PD decoupling to centralized inference
"FD_PD_CHANGEABLE": "FD_PD_CHANGEABLE":
lambda: os.getenv("FD_PD_CHANGEABLE", "1"), lambda: os.getenv("FD_PD_CHANGEABLE", "0"),
# Whether to use fastsafetensor load weight (0 or 1)
"FD_USE_FASTSAFETENSOR":
lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
} }

View File

@@ -27,6 +27,7 @@ from fastdeploy.input.text_processor import BaseDataProcessor
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
class ErnieProcessor(BaseDataProcessor): class ErnieProcessor(BaseDataProcessor):
""" """
Initialize the model instance. Initialize the model instance.
@@ -160,6 +161,7 @@ class ErnieProcessor(BaseDataProcessor):
if request.get('prompt'): if request.get('prompt'):
prompt = request.get('prompt') prompt = request.get('prompt')
prompt = prompt[0] if isinstance(prompt, list) else prompt prompt = prompt[0] if isinstance(prompt, list) else prompt
tokens = self.tokenizer.tokenize(prompt) tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens) token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request['prompt_token_ids'] = token_ids request['prompt_token_ids'] = token_ids

View File

@@ -82,6 +82,8 @@ class ErnieBotTokenizer(PretrainedTokenizer):
self.vocab_file = vocab_file self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor() self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file) self.sp_model.Load(vocab_file)
# Pre-compute the set of all special tokens to speed up decoding.
self.all_spec_tok = set(self.all_special_tokens)
@property @property
def space_token(self): def space_token(self):
@@ -143,7 +145,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):
# prev_is_special = False # prev_is_special = False
for token in tokens: for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model # make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens: if token in self.all_spec_tok:
# if not prev_is_special: # if not prev_is_special:
# out_string += " " # out_string += " "
out_string += self.sp_model.decode(current_sub_tokens) + token out_string += self.sp_model.decode(current_sub_tokens) + token
@@ -216,7 +218,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):
if hasattr(self, "do_lower_case") and self.do_lower_case: if hasattr(self, "do_lower_case") and self.do_lower_case:
# convert non-special tokens to lowercase # convert non-special tokens to lowercase
escaped_special_toks = [ escaped_special_toks = [
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_special_tokens) re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_spec_tok)
] ]
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)

View File

@@ -25,6 +25,7 @@ from fastdeploy.utils import data_processor_logger
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
class BaseDataProcessor(ABC): class BaseDataProcessor(ABC):
"""base class for data processor""" """base class for data processor"""

View File

@@ -16,10 +16,12 @@ from .attention import Attention
from .append_attn_backend import AppendAttentionBackend from .append_attn_backend import AppendAttentionBackend
from .attention_selecter import get_attention_backend from .attention_selecter import get_attention_backend
from .base_attention_backend import AttentionBackend from .base_attention_backend import AttentionBackend
from .mla_attention_backend import MLAAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend from .native_paddle_backend import PaddleNativeAttnBackend
from .xpu_attn_backend import XPUAttentionBackend from .xpu_attn_backend import XPUAttentionBackend
__all__ = [ __all__ = [
"Attention", "AttentionBackend", "PaddleNativeAttnBackend", "Attention", "AttentionBackend", "PaddleNativeAttnBackend",
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend" "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
"MLAAttentionBackend"
] ]

View File

@@ -187,6 +187,8 @@ class AppendAttentionBackend(AttentionBackend):
k: paddle.Tensor, k: paddle.Tensor,
v: paddle.Tensor, v: paddle.Tensor,
qkv: paddle.Tensor, qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention, layer: Attention,
forward_meta: ForwardMeta, forward_meta: ForwardMeta,
) -> paddle.Tensor: ) -> paddle.Tensor:

View File

@@ -111,6 +111,8 @@ class Attention(nn.Layer):
k: paddle.Tensor = None, k: paddle.Tensor = None,
v: paddle.Tensor = None, v: paddle.Tensor = None,
qkv: paddle.Tensor = None, qkv: paddle.Tensor = None,
compressed_kv: paddle.Tensor = None,
k_pe: paddle.Tensor = None,
forward_meta: ForwardMeta = None, forward_meta: ForwardMeta = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """
@@ -120,12 +122,16 @@ class Attention(nn.Layer):
k: the key tensor k: the key tensor
v: the value tensor v: the value tensor
forward_meta: the forward meta data forward_meta: the forward meta data
compressed_kv: optional compressed key-value cache (for MLA)
k_pe: optional key positional encoding (for MLA)
""" """
return forward_meta.attn_backend.forward( return forward_meta.attn_backend.forward(
q, q,
k, k,
v, v,
qkv, qkv,
compressed_kv,
k_pe,
self, self,
forward_meta, forward_meta,
) )
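For reference, a hedged sketch of how a caller might invoke this forward path for an MLA model; the argument names follow the signature above, while the idea that k/v/qkv can be left as None when the latent cache carries the keys and values is an assumption, not lifted from the model code:

def call_mla_attention(attn, forward_meta, q, compressed_kv, k_pe):
    # For MLA decode the backend reconstructs keys/values from the compressed
    # latent cache, so the regular k/v/qkv inputs are passed as None here.
    return attn(
        q=q,
        k=None,
        v=None,
        qkv=None,
        compressed_kv=compressed_kv,
        k_pe=k_pe,
        forward_meta=forward_meta,
    )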

View File

@@ -16,6 +16,7 @@
from functools import cache from functools import cache
from fastdeploy import envs
from fastdeploy.platforms import _Backend, current_platform from fastdeploy.platforms import _Backend, current_platform
from fastdeploy.utils import resolve_obj_from_strname from fastdeploy.utils import resolve_obj_from_strname
@@ -40,6 +41,7 @@ def _get_attn_backend(selected_backend: str) -> object:
return resolve_obj_from_strname(attention_cls) return resolve_obj_from_strname(attention_cls)
def get_attention_backend(selected_backend): def get_attention_backend() -> object:
"""Selects which attention backend.""" """Selects which attention backend."""
return _get_attn_backend(selected_backend) attention_backend = envs.FD_ATTENTION_BACKEND
return _get_attn_backend(attention_backend)
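The selector now reads the backend name from envs.FD_ATTENTION_BACKEND instead of taking an argument. A self-contained sketch of that env-driven, cached pattern (the registry contents and default value are illustrative, not the real mapping):

import os
from functools import cache

BACKEND_REGISTRY = {
    "APPEND_ATTN": "AppendAttentionBackend",
    "MLA_ATTN": "MLAAttentionBackend",
    "XPU_ATTN": "XPUAttentionBackend",
}

@cache
def _select_backend(name):
    # Cached so the name-to-class resolution happens only once per backend name.
    return BACKEND_REGISTRY[name]

def get_attention_backend_demo():
    return _select_backend(os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"))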

View File

@@ -46,6 +46,8 @@ class AttentionBackend(ABC):
k: paddle.Tensor, k: paddle.Tensor,
v: paddle.Tensor, v: paddle.Tensor,
qkv: paddle.Tensor, qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: paddle.nn.Layer, layer: paddle.nn.Layer,
forward_meta: ForwardMeta, forward_meta: ForwardMeta,
) -> paddle.Tensor: ) -> paddle.Tensor:
@@ -56,6 +58,8 @@ class AttentionBackend(ABC):
k: The key tensor. k: The key tensor.
v: The value tensor. v: The value tensor.
layer: The layer that will be used for the forward. layer: The layer that will be used for the forward.
compressed_kv: optional compressed key-value cache (for MLA)
k_pe: optional key positional encoding (for MLA)
forward_meta: The forward metadata. forward_meta: The forward metadata.
""" """
if forward_meta.forward_mode.is_mixed(): if forward_meta.forward_mode.is_mixed():
@@ -64,6 +68,8 @@ class AttentionBackend(ABC):
k, k,
v, v,
qkv, qkv,
compressed_kv,
k_pe,
layer, layer,
forward_meta, forward_meta,
) )
@@ -73,6 +79,8 @@ class AttentionBackend(ABC):
k, k,
v, v,
qkv, qkv,
compressed_kv,
k_pe,
layer, layer,
forward_meta, forward_meta,
) )
@@ -82,6 +90,8 @@ class AttentionBackend(ABC):
k, k,
v, v,
qkv, qkv,
compressed_kv,
k_pe,
layer, layer,
forward_meta, forward_meta,
) )
@@ -92,6 +102,8 @@ class AttentionBackend(ABC):
k: paddle.Tensor, k: paddle.Tensor,
v: paddle.Tensor, v: paddle.Tensor,
qkv: paddle.Tensor, qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: paddle.nn.Layer, layer: paddle.nn.Layer,
forward_meta: ForwardMeta, forward_meta: ForwardMeta,
) -> paddle.Tensor: ) -> paddle.Tensor:
@@ -104,6 +116,8 @@ class AttentionBackend(ABC):
k: paddle.Tensor, k: paddle.Tensor,
v: paddle.Tensor, v: paddle.Tensor,
qkv: paddle.Tensor, qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: paddle.nn.Layer, layer: paddle.nn.Layer,
forward_meta: ForwardMeta, forward_meta: ForwardMeta,
) -> paddle.Tensor: ) -> paddle.Tensor:
@@ -116,6 +130,8 @@ class AttentionBackend(ABC):
k: paddle.Tensor, k: paddle.Tensor,
v: paddle.Tensor, v: paddle.Tensor,
qkv: paddle.Tensor, qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: paddle.nn.Layer, layer: paddle.nn.Layer,
forward_meta: ForwardMeta, forward_meta: ForwardMeta,
) -> paddle.Tensor: ) -> paddle.Tensor:
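The extra compressed_kv/k_pe arguments are threaded through every branch of the mode dispatch above. A simplified sketch of that dispatch shape (the string mode is a stand-in for forward_meta.forward_mode, not the real API):

class DispatchDemo:
    def forward(self, *tensors, mode):
        if mode == "mixed":
            return self.forward_mixed(*tensors)
        if mode == "decode":
            return self.forward_decode(*tensors)
        return self.forward_extend(*tensors)

    def forward_mixed(self, *tensors):
        return "mixed"

    def forward_decode(self, *tensors):
        return "decode"

    def forward_extend(self, *tensors):
        return "extend"

assert DispatchDemo().forward(mode="decode") == "decode"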

View File

@@ -0,0 +1,490 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import math
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple
import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded
from fastdeploy.model_executor.layers.attention.ops import (
get_block_shape_and_split_kv_block, init_signal_layerwise,
open_shm_and_get_meta_signal)
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (decode_mla_write_cache,
multi_head_latent_attention,
prefill_mla_write_cache)
if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta
def yarn_get_mscale(scale=1, mscale=1):
"""
"""
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
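A quick numeric check of yarn_get_mscale (an illustrative aside, not part of the new file): with an assumed YaRN factor of 40 and mscale_all_dim of 1.0, the scale is 0.1 * ln(40) + 1 ≈ 1.3689.

import math

def yarn_get_mscale_demo(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

assert abs(yarn_get_mscale_demo(40, 1.0) - 1.3689) < 1e-3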
@dataclass
class MLAAttentionMetadata(AttentionMetadata):
"""
MLAAttentionMetadata for Multi-head Latent Attention (MLA)
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
_dtype: _DTypeLiteral = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: Optional[paddle.Tensor] = None
decoder_block_shape_q: Optional[paddle.Tensor] = None
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
class MLAAttentionBackend(AttentionBackend):
"""
MLA Attention Backend implementation.
"""
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int) -> None:
"""
MLAAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: MLAAttentionMetadata = None
# Basic configuration
self.block_size: int = fd_config.parallel_config.block_size
self.max_seq_len: int = fd_config.parallel_config.max_model_len
self.rope_theta: float = (10000.0
if fd_config.model_config.rope_theta is None
else fd_config.model_config.rope_theta)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_layers
# For Multi Head Latent Attention
self.kv_lora_rank: int = fd_config.model_config.deepseekv3.kv_lora_rank
self.qk_rope_head_dim: int = fd_config.model_config.deepseekv3.qk_rope_head_dim
self.qk_head_dim: int = fd_config.model_config.deepseekv3.qk_nope_head_dim \
+ fd_config.model_config.deepseekv3.qk_rope_head_dim
self.attn_softmax_scale: float = self.qk_head_dim**-0.5
if fd_config.model_config.deepseekv3.rope_scaling:
mscale_all_dim = fd_config.model_config.deepseekv3.rope_scaling.get(
"mscale_all_dim", False) # 1.0
scaling_factor = fd_config.model_config.deepseekv3.rope_scaling[
"factor"] # 40
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
# pd_disaggregation
self.use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
if self.device_id is None:
self.device_id = self.rank
else:
self.device_id = self.device_id.split(",")[self.rank]
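To make the softmax-scale setup in this constructor concrete (an aside with assumed DeepSeek-V3-like dims, not values taken from the config): with qk_nope_head_dim 128, qk_rope_head_dim 64 and the mscale from a factor-40 YaRN config, the effective scale works out as below.

import math

qk_nope_head_dim, qk_rope_head_dim = 128, 64           # assumed dims
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim       # 192
base_scale = qk_head_dim ** -0.5                        # ~0.0722
mscale = 0.1 * 1.0 * math.log(40) + 1.0                 # ~1.3689
attn_softmax_scale = base_scale * mscale * mscale       # ~0.1352
print(round(attn_softmax_scale, 4))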
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
metadata = MLAAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
metadata.max_partition_size = 32768
metadata.encoder_max_partition_size = self.max_seq_len
metadata._dtype = paddle.get_default_dtype()
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.max_len_kv,
metadata.set_max_lengths,
) = get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cum_offsets,
metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
# MLA
metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.use_pd_disaggregation:
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag)
self.attention_metadata: AttentionMetadata = metadata
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(self,
max_num_blocks: int) -> Tuple[int, int, int, int]:
"""
Calculate kv cache shape for MLA
"""
return (max_num_blocks, 1, self.block_size,
self.kv_lora_rank + self.qk_rope_head_dim)
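A worked example of the MLA cache shape above (sizes are DeepSeek-V3-like assumptions): with block_size 64, kv_lora_rank 512 and qk_rope_head_dim 64, each cache block stores 64 tokens of a single 576-wide latent vector instead of per-head K/V tensors.

max_num_blocks, block_size = 1024, 64
kv_lora_rank, qk_rope_head_dim = 512, 64
kv_cache_shape = (max_num_blocks, 1, block_size, kv_lora_rank + qk_rope_head_dim)
print(kv_cache_shape)  # (1024, 1, 64, 576)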
def forward_extend(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass for the prefill stage.
"""
metadata = self.attention_metadata
if self.use_pd_disaggregation:
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index)
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
forward_meta, 'caches') else None
# Write into the latent cache
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
"none",
getattr(forward_meta, 'max_input_length', -1),
)
# Flash attention computation
fmha_out = flash_attn_unpadded(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
self.attn_softmax_scale,
causal=True,
training=False,
)[0]
return fmha_out
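The prefill path relies on the varlen flash-attention convention of cumulative sequence lengths. As a small illustration (the lengths are made up), cu_seqlens is the exclusive prefix sum of the per-request prompt lengths:

seq_lens = [3, 5, 2]            # hypothetical prompt lengths in one batch
cu_seqlens = [0]
for n in seq_lens:
    cu_seqlens.append(cu_seqlens[-1] + n)
assert cu_seqlens == [0, 3, 8, 10]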
def forward_decode(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass for the decode stage.
"""
metadata = self.attention_metadata
if self.use_pd_disaggregation:
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index)
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
forward_meta, 'caches') else None
# Get speculative decoding parameters
speculate_decoder = self.speculative_method is not None
speculate_max_tokens = self.speculate_max_draft_token_num
# Write into the latent cache
decode_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)
# Multi-head latent attention computation
fmha_out = multi_head_latent_attention(
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.decoder_num_blocks,  # PaddleNLP passes decoder_num_blocks_cpu here
metadata.max_enc_len_this_time,
metadata.max_dec_len_this_time,
metadata.max_len_kv,
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
None, # cache_k_quant_scales
None, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # out_shifts
None, # out_smooths
metadata._fuse_kernel_compute_dtype,
"none", # cache_quant_type
self.kv_lora_rank,
self.max_seq_len,
self.attn_softmax_scale,
0.0, # quant_max_bound
0.0, # quant_min_bound
0.0, # out_linear_in_scale
speculate_max_tokens,
True, # causal
speculate_decoder,
)
return fmha_out
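A back-of-the-envelope view of why decode only writes compressed_kv and k_pe (the head counts below are assumptions for illustration): per token, the MLA latent cache stores kv_lora_rank + qk_rope_head_dim values instead of full per-head K/V.

kv_lora_rank, qk_rope_head_dim = 512, 64   # assumed MLA dims
num_kv_heads, head_dim = 128, 128          # assumed MHA-equivalent dims
mla_width = kv_lora_rank + qk_rope_head_dim     # 576 values per token
mha_width = 2 * num_kv_heads * head_dim         # 32768 values per token
print(round(mha_width / mla_width, 1))          # ~56.9x narrower cache rows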
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass for mixed (prefill + decode) mode.
"""
metadata = self.attention_metadata
speculate_decoder = self.speculative_method is not None
speculate_max_tokens = self.speculate_max_draft_token_num
decode_stage = forward_meta.is_decode_batch
prefill_stage = not (forward_meta.is_decode_batch)
if self.use_pd_disaggregation:
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index)
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
forward_meta, 'caches') else None
if prefill_stage:
# Write into the latent cache
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
"none",
self.max_seq_len,
)
# FA
fmha_out = flash_attn_unpadded(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
self.attn_softmax_scale,
causal=True,
training=False,
)[0]
return fmha_out
# Decode
if decode_stage:
# MLA: write into the latent cache
decode_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)
# Multi-head latent attention computation
fmha_out = multi_head_latent_attention(
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.decoder_num_blocks,  # PaddleNLP passes decoder_num_blocks_cpu here
metadata.max_enc_len_this_time,
metadata.max_dec_len_this_time,
metadata.max_len_kv,
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
None, # cache_k_quant_scales
None, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # out_shifts
None, # out_smooths
metadata._fuse_kernel_compute_dtype,
"none", # cache_quant_type
self.kv_lora_rank,
self.max_seq_len,
self.attn_softmax_scale,
0.0, # quant_max_bound
0.0, # quant_min_bound
0.0, # out_linear_in_scale
speculate_max_tokens,
True, # causal
speculate_decoder,
)
return fmha_out

View File

@@ -149,6 +149,8 @@ class XPUAttentionBackend(AttentionBackend):
k: paddle.Tensor, k: paddle.Tensor,
v: paddle.Tensor, v: paddle.Tensor,
qkv: paddle.Tensor, qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention, layer: Attention,
forward_meta: ForwardMeta, forward_meta: ForwardMeta,
) -> paddle.Tensor: ) -> paddle.Tensor:

View File

@@ -41,16 +41,12 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
""" """
Create weights for linear layer on XPU Create weights for linear layer on XPU
""" """
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
layer.linear_weight_shape.reverse() layer.linear_weight_shape.reverse()
if self.quant_config.name() == "weight_only_int4": if self.quant_config.name() == "weight_only_int4":
layer.linear_weight_shape[0] //= 2 layer.linear_weight_shape[0] //= 2
layer.weight_dtype = "int8" layer.weight_dtype = "int8"
linear_weight_scale_shape = [layer.embed_dim]
if hasattr(layer, "linear_weight_shape"):
if isinstance(layer.linear_weight_shape, list):
layer_weight_shape = layer.linear_weight_shape
linear_weight_scale_shape = layer_weight_shape[:1]
layer.linear_weight_scale = layer.create_parameter( layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape, shape=linear_weight_scale_shape,
dtype="float32", dtype="float32",

View File

@@ -14,10 +14,15 @@
# limitations under the License. # limitations under the License.
""" """
from typing import Dict
import numpy as np
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.distributed import fleet from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from .utils import get_tensor from .utils import get_tensor
@@ -28,12 +33,12 @@ class VocabParallelEmbedding(nn.Layer):
def __init__( def __init__(
self, self,
fd_config, fd_config: FDConfig,
num_embeddings, num_embeddings: int,
embedding_dim=768, embedding_dim: int = 768,
params_dtype="bfloat16", params_dtype: str = "bfloat16",
prefix="", prefix="",
): ) -> None:
""" """
Initialize the VocabParallelEmbedding layer for the model. Initialize the VocabParallelEmbedding layer for the model.
@@ -41,28 +46,28 @@ class VocabParallelEmbedding(nn.Layer):
fd_config (FDConfig): Arguments related to inference, containing fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim, attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size. num_attention_heads, and ffn_hidden_size.
num_embeddings : vocabulary size. num_embeddings (int) : vocabulary size.
embedding_dim : size of hidden state. embedding_dim (int) : size of hidden state.
params_dtype : data type of parameters. params_dtype (str) : data type of parameters.
prefix (str): Unique name of the layer, used for naming internal attributes, prefix (str): The name of current layer. Defaults to "".
you can give it any name you like.
""" """
super().__init__() super().__init__()
self.fd_config = fd_config self.fd_config = fd_config
hcg = fleet.get_hybrid_communicate_group() hcg = fleet.get_hybrid_communicate_group()
self.mp_rank = hcg.get_model_parallel_rank() self.mp_rank: int = hcg.get_model_parallel_rank()
self.column_cut = fd_config.parallel_config.column_cut self.column_cut = False
self.world_size = hcg.get_model_parallel_world_size() self.world_size: int = hcg.get_model_parallel_world_size()
self.ring_id = hcg.get_model_parallel_group().id self.ring_id: int = hcg.get_model_parallel_group().id
self.use_rope = fd_config.model_config.use_rope self.use_rope: bool = fd_config.model_config.use_rope
self.rope_head_dim = fd_config.model_config.rope_head_dim self.rope_head_dim: int = fd_config.model_config.rope_head_dim
self.use_ep = fd_config.parallel_config.use_ep self.use_ep: bool = fd_config.parallel_config.use_ep
self.hidden_dropout_prob = fd_config.model_config.hidden_dropout_prob self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
self.initializer_range = fd_config.model_config.initializer_range self.initializer_range: float = fd_config.model_config.initializer_range
self.sequence_parallel = fd_config.parallel_config.sequence_parallel self.sequence_parallel: bool = fd_config.parallel_config.sequence_parallel
self.max_position_embeddings = fd_config.model_config.max_position_embeddings self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
self.freeze_embedding = fd_config.model_config.freeze_embedding self.freeze_embedding: bool = fd_config.model_config.freeze_embedding
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
self.params_dtype: str = params_dtype
if self.use_ep: if self.use_ep:
self.word_embeddings = nn.Embedding( self.word_embeddings = nn.Embedding(
@@ -109,7 +114,8 @@ class VocabParallelEmbedding(nn.Layer):
self.rope_head_dim_shape_tensor = paddle.ones((self.rope_head_dim), self.rope_head_dim_shape_tensor = paddle.ones((self.rope_head_dim),
dtype="int8") dtype="int8")
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict: Dict[str,
paddle.Tensor | np.ndarray]):
""" """
Load the checkpoint state dictionary into the layer. Load the checkpoint state dictionary into the layer.
@@ -125,7 +131,7 @@ class VocabParallelEmbedding(nn.Layer):
get_tensor(state_dict.pop(self.prefix + ".weight")).astype( get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
paddle.get_default_dtype())) paddle.get_default_dtype()))
def forward(self, ids_remove_padding=None): def forward(self, ids_remove_padding=None) -> paddle.Tensor:
""" """
Defines the forward computation of the layer. Defines the forward computation of the layer.

View File

@@ -216,6 +216,14 @@ class ReplicatedLinear(LinearBase):
with_bias=with_bias, with_bias=with_bias,
add_bias=add_bias, add_bias=add_bias,
skip_quant=skip_quant) skip_quant=skip_quant)
self.hidden_size = fd_config.model_config.hidden_size
self.linear_weight_shape = [
self.input_size,
self.output_size,
]
if fd_config.quant_config:
self.quant_method.create_weights(self)
self.init_weight() self.init_weight()
@@ -259,7 +267,10 @@ class ColumnParallelLinear(LinearBase):
skip_quant=skip_quant) skip_quant=skip_quant)
self.nranks = fd_config.parallel_config.tensor_parallel_degree self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.input_size = input_size self.input_size = input_size
self.output_size = divide(output_size, self.nranks) self.output_size = divide(
output_size,
self.nranks) # Split the output_size using TP inference.
self.hidden_size = fd_config.model_config.hidden_size
self.linear_weight_shape = [ self.linear_weight_shape = [
self.input_size, self.input_size,
self.output_size, self.output_size,
@@ -339,7 +350,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
""" """
self.use_fast_ffn = use_fast_ffn self.use_fast_ffn = use_fast_ffn
self.activation = activation self.activation = activation
self.embed_dim = fd_config.model_config.hidden_size self.hidden_size = fd_config.model_config.hidden_size
self.nranks = fd_config.parallel_config.tensor_parallel_degree self.nranks = fd_config.parallel_config.tensor_parallel_degree
super().__init__(fd_config=fd_config, super().__init__(fd_config=fd_config,
@@ -413,12 +424,12 @@ class QKVParallelLinear(ColumnParallelLinear):
""" """
self.num_heads = fd_config.model_config.num_attention_heads self.num_heads = fd_config.model_config.num_attention_heads
self.kv_num_heads = fd_config.model_config.num_key_value_heads self.kv_num_heads = fd_config.model_config.num_key_value_heads
self.embed_dim = fd_config.model_config.hidden_size self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim self.head_dim = fd_config.model_config.head_dim
self.nranks = fd_config.parallel_config.tensor_parallel_degree self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.num_heads_per_rank = divide(self.num_heads, self.nranks) self.num_heads_per_rank = divide(self.num_heads, self.nranks)
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks) self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
input_size = self.embed_dim input_size = self.hidden_size
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
super().__init__(fd_config=fd_config, super().__init__(fd_config=fd_config,
prefix=prefix, prefix=prefix,
@@ -448,7 +459,7 @@ class QKVParallelLinear(ColumnParallelLinear):
weight_tensor = weight_tensor.reshape([ weight_tensor = weight_tensor.reshape([
(self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) *
(self.head_dim), (self.head_dim),
self.embed_dim, self.hidden_size,
]) ])
weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0]) weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])
@@ -513,6 +524,7 @@ class RowParallelLinear(LinearBase):
output_size: int = None, output_size: int = None,
with_bias: bool = False, with_bias: bool = False,
add_bias: bool = False, add_bias: bool = False,
reduce_results: bool = True,
skip_quant: bool = False, skip_quant: bool = False,
): ):
""" """
@@ -538,10 +550,14 @@ class RowParallelLinear(LinearBase):
self.fd_config = fd_config self.fd_config = fd_config
self.skip_quant = False self.skip_quant = False
self.nranks = fd_config.parallel_config.tensor_parallel_degree self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.embed_dim = fd_config.model_config.hidden_size self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim self.head_dim = fd_config.model_config.head_dim
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
# Split input_size when using TP inference.
self.input_size = divide(input_size, self.nranks)
self.output_size = output_size
self.linear_weight_shape = [ self.linear_weight_shape = [
self.input_size, self.input_size,
self.output_size, self.output_size,
@@ -551,6 +567,8 @@ class RowParallelLinear(LinearBase):
if fd_config.quant_config: if fd_config.quant_config:
self.quant_method = fd_config.quant_config.get_quant_method(self) self.quant_method = fd_config.quant_config.get_quant_method(self)
self.quant_method.create_weights(self) self.quant_method.create_weights(self)
self.reduce_results = reduce_results
self.init_weight() self.init_weight()
def init_weight(self): def init_weight(self):
@@ -570,7 +588,7 @@ class RowParallelLinear(LinearBase):
self.linear_bias = None self.linear_bias = None
if self.with_bias: if self.with_bias:
self.linear_bias = self.create_parameter( self.linear_bias = self.create_parameter(
shape=[self.embed_dim], shape=[self.hidden_size],
dtype=self._dtype, dtype=self._dtype,
is_bias=True, is_bias=True,
) )
@@ -589,7 +607,7 @@ class RowParallelLinear(LinearBase):
else: else:
out = paddle.matmul(x, self.linear_weight) out = paddle.matmul(x, self.linear_weight)
if self.nranks > 1: if self.reduce_results and self.nranks > 1:
tensor_model_parallel_all_reduce(out) tensor_model_parallel_all_reduce(out)
return out return out
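The new reduce_results flag lets callers defer the tensor-parallel all-reduce when a later op already combines the partial results. A numpy sketch of why the per-rank partial sums add up to the full matmul:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
w = rng.standard_normal((8, 4))

# Two "ranks", each owning half of the reduction (K) dimension.
partials = [x[:, :4] @ w[:4], x[:, 4:] @ w[4:]]
assert np.allclose(sum(partials), x @ w)   # the all-reduce just sums partials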

View File

@@ -14,10 +14,15 @@
# limitations under the License. # limitations under the License.
""" """
from typing import Dict, Optional
import numpy as np
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.distributed import fleet from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from .utils import get_tensor from .utils import get_tensor
@@ -28,12 +33,12 @@ class ParallelLMHead(nn.Layer):
def __init__( def __init__(
self, self,
fd_config, fd_config: FDConfig,
num_embeddings, num_embeddings: int,
embedding_dim, embedding_dim: int,
prefix="", prefix: str = "",
with_bias=False, with_bias: bool = False,
): ) -> None:
""" """
Parallelized LMhead. Parallelized LMhead.
@@ -43,21 +48,22 @@ class ParallelLMHead(nn.Layer):
num_attention_heads, and ffn_hidden_size. num_attention_heads, and ffn_hidden_size.
num_embeddings (int): vocabulary size. num_embeddings (int): vocabulary size.
embedding_dim (int): size of hidden state. embedding_dim (int): size of hidden state.
prefix (str): full name of the layer in the state dict prefix (str): The name of current layer. Defaults to "".
with_bias (bool): whether to have bias. Default: False.
""" """
super(ParallelLMHead, self).__init__() super(ParallelLMHead, self).__init__()
self.linear_weight_key = prefix + ".weight" self.linear_weight_key: str = prefix + ".weight"
if with_bias: if with_bias:
self.linear_bias_key = prefix + ".bias" self.linear_bias_key: Optional[str] = prefix + ".bias"
else: else:
self.linear_bias_key = None self.linear_bias_key: Optional[str] = None
self.use_ep = fd_config.parallel_config.use_ep self.use_ep: bool = fd_config.parallel_config.use_ep
self.column_cut = True self.column_cut = True
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear RowParallelLinear = fleet.meta_parallel.RowParallelLinear
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
if self.use_ep: if self.use_ep:
self.weight = self.create_parameter( self.weight = self.create_parameter(
@@ -92,7 +98,8 @@ class ParallelLMHead(nn.Layer):
fuse_matmul_bias=False,  # False gives a smaller numerical diff fuse_matmul_bias=False,  # False gives a smaller numerical diff
) )
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict: Dict[str,
paddle.Tensor | np.ndarray]):
""" """
Load the checkpoint state dictionary into the layer. Load the checkpoint state dictionary into the layer.
@@ -122,7 +129,7 @@ class ParallelLMHead(nn.Layer):
paddle.get_default_dtype()) paddle.get_default_dtype())
self.out_linear.bias.set_value(bias) self.out_linear.bias.set_value(bias)
def forward(self, input): def forward(self, input: paddle.Tensor) -> paddle.Tensor:
""" """
Defines the forward computation of the layer. Defines the forward computation of the layer.

View File

@@ -22,13 +22,34 @@ from paddleformers.utils.log import logger
import fastdeploy import fastdeploy
from fastdeploy.distributed.communication_op import \ from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce tensor_model_parallel_all_reduce
from ..utils import get_tensor, create_and_set_parameter from fastdeploy.platforms import current_platform
from ..utils import create_and_set_parameter, get_tensor
from .fused_moe_backend_base import MoEMethodBase from .fused_moe_backend_base import MoEMethodBase
from fastdeploy.platforms import current_platform
if current_platform.is_cuda(): if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce moe_expert_reduce, noaux_tc)
# used for deepseek_v3
def get_moe_scores(gating_output: paddle.Tensor, n_group, topk_group, top_k,
routed_scaling_factor,
e_score_correction_bias) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
scores = noaux_tc(
scores,
scores_with_bias,
n_group,
topk_group,
top_k,
routed_scaling_factor,
)
return scores
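A hedged numpy sketch of the aux-loss-free routing idea that get_moe_scores delegates to the noaux_tc op: the correction bias only affects which experts get picked, while the returned weights come from the unbiased sigmoid scores. The group-limited (n_group/topk_group) selection is omitted for brevity, so this is an approximation, not the op's exact behavior.

import numpy as np

def noaux_routing_demo(logits, bias, top_k, routed_scaling_factor=1.0):
    scores = 1.0 / (1.0 + np.exp(-logits))                     # sigmoid gate scores
    selected = np.argsort(-(scores + bias), axis=-1)[..., :top_k]
    weights = np.take_along_axis(scores, selected, axis=-1)    # unbiased weights
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return selected, weights * routed_scaling_factor

sel, w = noaux_routing_demo(np.zeros((1, 8)), np.arange(8.0) * 0.01, top_k=2)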
class CutlassMoEMethod(MoEMethodBase): class CutlassMoEMethod(MoEMethodBase):
@@ -199,6 +220,30 @@ class CutlassMoEMethod(MoEMethodBase):
""" """
Paddle Cutlass compute Fused MoE. Paddle Cutlass compute Fused MoE.
""" """
if layer.topk_method == "noaux_tc":
gate_out = get_moe_scores(gate_out, layer.n_group,
layer.topk_group, layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias)
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
) = moe_expert_dispatch(
x,
gate_out,
None, # Use layer.gate_correction_bias in get_moe_scores.
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
topk_only_mode=True,
)
else:
( (
permute_input, permute_input,
token_nums_per_expert, token_nums_per_expert,
@@ -234,11 +279,11 @@ class CutlassMoEMethod(MoEMethodBase):
permute_indices_per_token, permute_indices_per_token,
topk_idx, topk_idx,
None, None,
norm_topk_prob=True, norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
routed_scaling_factor=1.0, routed_scaling_factor=1.0,
) )
if layer.tp_size > 1: if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(fused_moe_out) tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out return fused_moe_out

View File

@@ -195,8 +195,6 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
hidden_size = layer.hidden_size hidden_size = layer.hidden_size
num_experts = layer.num_experts num_experts = layer.num_experts
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out, gate_out,
layer.gate_correction_bias, layer.gate_correction_bias,

View File

@@ -17,6 +17,7 @@
import paddle import paddle
from paddle import nn from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication_op import \ from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map, from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
@@ -25,17 +26,24 @@ from fastdeploy.utils import ceil_div
from ..quantization.quant_base import QuantMethodBase from ..quantization.quant_base import QuantMethodBase
try:
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
from .triton_moe_kernels import fused_moe_kernel_paddle
except:
pass
class TritonWeightOnlyMoEMethod(QuantMethodBase): class TritonWeightOnlyMoEMethod(QuantMethodBase):
""" """
Use Triton Group Gemm to compute Fused MoE. Use Triton Group Gemm to compute Fused MoE.
""" """
def __init__(self, quant_method=None): def __init__(self, quant_config=None):
""" """
Triton Group Gemm to compute Fused MoE. Triton Group Gemm to compute Fused MoE.
""" """
self.quant_method = quant_method self.quant_config = quant_config
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"] self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_scale_attrs = [ self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale" "moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
@@ -52,7 +60,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts assert len(ffn2_weights) == layer.num_local_experts
assert layer.quant_method.quant_config.name() == "wint8"
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
assert ffn1_weights[0].shape == [ assert ffn1_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2 layer.hidden_size, layer.moe_intermediate_size * 2
] ]
@@ -63,9 +75,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
ffn1_tensor = paddle.stack(ffn1_weights, axis=0) ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0) ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
if self.quant_config.name() == "wint8": if algo == "wint8":
max_bound = 127 max_bound = 127
elif self.quant_config.name() == "wint4": elif algo == "wint4":
max_bound = 7 max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]): for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
@@ -111,15 +123,13 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
moe_intermediate_size = layer.moe_intermediate_size moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size hidden_size = layer.hidden_size
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight) topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
scores = paddle.nn.functional.softmax(gate_out, axis=-1) gate_out,
layer.gate_correction_bias,
topk_weights, topk_ids = paddle.topk(scores, top_k,
k=top_k, True, # apply_norm_weight,
axis=-1, False,
sorted=False) )
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
intermediate_cache1 = paddle.empty( intermediate_cache1 = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2], [token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype, dtype=x.dtype,
@@ -139,13 +149,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
} }
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
from .triton_moe_kernels import fused_moe_kernel_paddle
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]) topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
max_num_tokens_padded = sorted_token_ids.shape[0] max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), ) ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid]( fused_moe_kernel_paddle[grid](
@@ -158,10 +166,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
moe_intermediate_size * 2, max_possible_num_post_padded,
hidden_size,
max_num_tokens_padded,
token_num * top_k, token_num * top_k,
N=moe_intermediate_size * 2,
K=hidden_size,
stride_am=x.strides[0], stride_am=x.strides[0],
stride_ak=x.strides[1], stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0], stride_be=layer.moe_ffn1_weight.strides[0],
@@ -193,7 +201,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
intermediate_cache2 = paddle.incubate.nn.functional.swiglu( intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
intermediate_cache1) intermediate_cache1)
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), ) ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid]( fused_moe_kernel_paddle[grid](
intermediate_cache2, intermediate_cache2,
@@ -205,10 +214,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
hidden_size, max_possible_num_post_padded,
moe_intermediate_size,
max_num_tokens_padded,
token_num * top_k, token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=intermediate_cache2.strides[0], stride_am=intermediate_cache2.strides[0],
stride_ak=intermediate_cache2.strides[1], stride_ak=intermediate_cache2.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0], stride_be=layer.moe_ffn2_weight.strides[0],
@@ -324,7 +333,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
moe_intermediate_size = layer.moe_intermediate_size moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size hidden_size = layer.hidden_size
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
scores = paddle.nn.functional.softmax(gate_out, axis=-1) scores = paddle.nn.functional.softmax(gate_out, axis=-1)
topk_weights, topk_ids = paddle.topk(scores, topk_weights, topk_ids = paddle.topk(scores,
@@ -352,12 +360,12 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
} }
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess( sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]) topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
max_num_tokens_padded = sorted_token_ids.shape[0] max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), ) ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
adamard_matrix = create_hadamard_matrix_map[hidden_size] adamard_matrix = create_hadamard_matrix_map[hidden_size]
@@ -371,8 +379,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
permute_x = permute_x / quant_activation_scale permute_x = permute_x / quant_activation_scale
permute_x = permute_x.astype("float8_e4m3fn") permute_x = permute_x.astype("float8_e4m3fn")
from .triton_moe_kernels import fused_moe_kernel_paddle
fused_moe_kernel_paddle[grid]( fused_moe_kernel_paddle[grid](
permute_x, permute_x,
layer.moe_ffn1_weight.view(paddle.float8_e4m3fn), layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
@@ -383,10 +389,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
moe_intermediate_size * 2, max_possible_num_post_padded,
hidden_size,
max_num_tokens_padded,
token_num * top_k, token_num * top_k,
N=moe_intermediate_size * 2,
K=hidden_size,
stride_am=x.strides[0], stride_am=x.strides[0],
stride_ak=x.strides[1], stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0], stride_be=layer.moe_ffn1_weight.strides[0],
@@ -426,7 +432,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
intermediate_cache2 = intermediate_cache2 / quant_activation_scale intermediate_cache2 = intermediate_cache2 / quant_activation_scale
intermediate_cache2 = intermediate_cache2.astype("float8_e4m3fn") intermediate_cache2 = intermediate_cache2.astype("float8_e4m3fn")
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), ) ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid]( fused_moe_kernel_paddle[grid](
@@ -439,10 +446,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
hidden_size, max_possible_num_post_padded,
moe_intermediate_size,
max_num_tokens_padded,
token_num * top_k, token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=intermediate_cache2.strides[0], stride_am=intermediate_cache2.strides[0],
stride_ak=intermediate_cache2.strides[1], stride_ak=intermediate_cache2.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0], stride_be=layer.moe_ffn2_weight.strides[0],
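For orientation, sorted_token_ids is padded so every expert's token count is a multiple of BLOCK_SIZE_M, which is where max_possible_num_post_padded and the kernel grid size come from. The worst-case bound below is the usual MoE block-alignment formula and is stated here as an assumption about what tritonmoe_preprocess_func produces; the sizes are illustrative.

def ceil_div(a, b):
    return (a + b - 1) // b

token_num, top_k, num_local_experts = 256, 8, 64
BLOCK_SIZE_M, BLOCK_SIZE_N = 64, 64
moe_intermediate_size = 1536

# Each expert's token list is padded up to a multiple of BLOCK_SIZE_M.
max_possible_num_post_padded = token_num * top_k + num_local_experts * (BLOCK_SIZE_M - 1)
grid = (ceil_div(max_possible_num_post_padded, BLOCK_SIZE_M) *
        ceil_div(moe_intermediate_size * 2, BLOCK_SIZE_N))
print(max_possible_num_post_padded, grid)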

View File

@@ -224,6 +224,7 @@ class TritonWint2FusedMoeMethod(Wint2MoeMethod):
) )
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce from fastdeploy.model_executor.ops.gpu import moe_expert_reduce
fused_moe_out = moe_expert_reduce( fused_moe_out = moe_expert_reduce(
ffn_out, ffn_out,
topk_weights, topk_weights,

View File

@@ -30,10 +30,15 @@ class FusedMoE(nn.Layer):
def __init__( def __init__(
self, self,
fd_config, fd_config,
reduce_results: bool = True,
moe_intermediate_size: int = -1, moe_intermediate_size: int = -1,
num_experts: int = -1, num_experts: int = -1,
expert_id_offset: int = 0, expert_id_offset: int = 0,
top_k: int = -1, top_k: int = -1,
topk_method: str = "",
topk_group: int = -1,
n_group: int = -1,
routed_scaling_factor: float = 1.0,
layer_idx: int = -1, layer_idx: int = -1,
moe_tag: str = "", moe_tag: str = "",
weight_key_map: dict = {}, weight_key_map: dict = {},
@@ -49,6 +54,7 @@ class FusedMoE(nn.Layer):
self.fd_config = fd_config self.fd_config = fd_config
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.reduce_results = reduce_results
self.tp_size = fd_config.parallel_config.tensor_parallel_degree self.tp_size = fd_config.parallel_config.tensor_parallel_degree
self.ep_size = fd_config.parallel_config.expert_parallel_degree self.ep_size = fd_config.parallel_config.expert_parallel_degree
@@ -60,28 +66,32 @@ class FusedMoE(nn.Layer):
self.hidden_size = fd_config.model_config.hidden_size self.hidden_size = fd_config.model_config.hidden_size
self.moe_config = fd_config.moe_config self.moe_config = fd_config.moe_config
self.num_experts = num_experts self.num_experts = num_experts
self.num_local_experts = self.num_experts // self.ep_size self.num_local_experts = self.num_experts // self.ep_size
self.moe_intermediate_size = moe_intermediate_size // self.tp_size self.moe_intermediate_size = moe_intermediate_size // self.tp_size
self.top_k = top_k self.top_k = top_k
self.hidden_size = self.hidden_size
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
self.weight_key_map = weight_key_map self.weight_key_map = weight_key_map
self.use_method = envs.FD_MOE_BACKEND.lower() self.use_method = envs.FD_MOE_BACKEND.lower()
self.gate_correction_bias = None self.gate_correction_bias = None
self.moe_tag = moe_tag self.moe_tag = moe_tag
if self.ep_size > 1: if self.ep_size > 1:
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
self.expert_id_offset = expert_id_offset self.expert_id_offset = expert_id_offset
if fd_config.quant_config: # used for deepseek_v3
self.quant_method = fd_config.quant_config.get_quant_method(self) self.topk_method = topk_method
self.topk_group = topk_group
self.n_group = n_group
self.routed_scaling_factor = routed_scaling_factor
moe_quant_config = fd_config.quant_config
if moe_quant_config:
self.quant_method = moe_quant_config.get_quant_method(self)
self.moe_quant_type = moe_quant_config.name()
else: else:
# now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
from .fused_moe_cutlass_backend import CutlassMoEMethod from .fused_moe_cutlass_backend import CutlassMoEMethod
@@ -90,12 +100,78 @@ class FusedMoE(nn.Layer):
if self.ep_size > 1: if self.ep_size > 1:
self.quant_method.init_ep(self) self.quant_method.init_ep(self)
if fd_config.load_config.dynamic_load_weight:
# Used by RL workflows to build the model.
self.init_moe_weights()
logger.info( logger.info(
f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset+self.num_local_experts}), \ f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset+self.num_local_experts}), \
{top_k=}, hidden_size={self.hidden_size}, {moe_intermediate_size=}, \ {top_k=}, hidden_size={self.hidden_size}, {moe_intermediate_size=}, \
, ep_size={self.ep_size}, \ , ep_size={self.ep_size}, \
tp_size={self.tp_size}.") tp_size={self.tp_size}.")
def init_moe_weights(self):
"""
Initialize the weight shapes and parameters for the MoE layer.
Combines weight shape initialization and parameter creation into a single function.
"""
# Initialize weight shapes
self._dtype = self._helper.get_default_dtype()
self.weight_dtype = self._dtype
gate_weight_shape = [self.hidden_size, self.num_experts]
gate_correction_bias_shape = [1, self.num_experts]
self.gate_weight = self.create_parameter(
shape=gate_weight_shape,
dtype="float32",
)
if self.moe_config.moe_use_aux_free:
self.gate_correction_bias = self.create_parameter(
shape=gate_correction_bias_shape,
dtype="float32",
)
ffn1_output_dim = self.moe_intermediate_size * 2
if self.moe_quant_type in ["fp8", "wint8"]:
ffn1_weight_shape = [self.num_local_experts, ffn1_output_dim, self.hidden_size]
ffn2_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
else:
ffn1_weight_shape = [self.num_local_experts, self.hidden_size, ffn1_output_dim]
ffn2_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
# Create parameters
if self.moe_quant_type == "fp8":
# TODO(gaoziyuan): fp8 weight initialization is not implemented yet.
pass
else:
self.weight_dtype = "int8"
self.init_weight_only_scale()
# FFN1 parameters
self.moe_ffn1_weight = self.create_parameter(
shape=ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
# FFN2 parameters
self.moe_ffn2_weight = self.create_parameter(
shape=ffn2_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
def init_weight_only_scale(self):
"""
Initialize the weight scale.
"""
self.moe_ffn1_weight_scale = self.create_parameter(
shape=[self.num_local_experts, self.moe_intermediate_size * 2],
dtype=self._dtype,
)
self.moe_ffn2_weight_scale = self.create_parameter(
shape=[self.num_local_experts, self.hidden_size],
dtype=self._dtype,
)
def load_experts_weight(self, state_dict: dict, def load_experts_weight(self, state_dict: dict,
ffn1_expert_weight_key: str, ffn1_expert_weight_key: str,
ffn2_expert_weight_key: str): ffn2_expert_weight_key: str):
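A worked shape example for init_moe_weights (all sizes below are made-up illustrations): with 8 local experts, hidden size 4096 and per-rank moe_intermediate_size 1536, the wint8/fp8 layout keeps ffn1 as [E, 2*I, H] with a per-channel scale of [E, 2*I].

E, H, I = 8, 4096, 1536                     # assumed sizes
ffn1_weight_shape = [E, 2 * I, H]           # "fp8"/"wint8" layout from the code above
ffn2_weight_shape = [E, H, I]
moe_ffn1_weight_scale_shape = [E, 2 * I]
moe_ffn2_weight_scale_shape = [E, H]
print(ffn1_weight_shape, moe_ffn1_weight_scale_shape)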

View File

@@ -16,9 +16,10 @@
import triton import triton
import triton.language as tl import triton.language as tl
from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import paddle_use_triton_v2
@triton.jit @paddle_use_triton_v2()
def fused_moe_kernel_paddle( def fused_moe_kernel_paddle(
a_ptr, a_ptr,
b_ptr, b_ptr,
@@ -31,22 +32,22 @@ def fused_moe_kernel_paddle(
num_tokens_post_padded_ptr, num_tokens_post_padded_ptr,
# Matrix dimensions # Matrix dimensions
N, max_possible_num_post_padded,
K,
num_tokens_post_padded,
num_valid_tokens, num_valid_tokens,
stride_am, N: tl.constexpr,
stride_ak, K: tl.constexpr,
stride_be, stride_am: tl.constexpr,
stride_bk, stride_ak: tl.constexpr,
stride_bn, stride_be: tl.constexpr,
stride_cm, stride_bk: tl.constexpr,
stride_cn, stride_bn: tl.constexpr,
stride_asm, stride_cm: tl.constexpr,
stride_ask, stride_cn: tl.constexpr,
stride_bse, stride_asm: tl.constexpr,
stride_bsk, stride_ask: tl.constexpr,
stride_bsn, stride_bse: tl.constexpr,
stride_bsk: tl.constexpr,
stride_bsn: tl.constexpr,
# Block size for block-wise fp8 quantization # Block size for block-wise fp8 quantization
group_n: tl.constexpr, group_n: tl.constexpr,
group_k: tl.constexpr, group_k: tl.constexpr,
@@ -87,7 +88,7 @@ def fused_moe_kernel_paddle(
multiplication across different blocks processed by the same expert. multiplication across different blocks processed by the same expert.
""" """
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M) num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group group_id = pid // num_pid_in_group
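The group_id arithmetic visible above follows the standard Triton grouped-ordering scheme, which walks GROUP_SIZE_M row-blocks before advancing to the next column block so the B tiles stay hot in L2. A pure-Python rendering of that mapping (a sketch of the usual formula, not code lifted from this kernel):

def tile_for_pid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

print([tile_for_pid(p, num_pid_m=4, num_pid_n=4, GROUP_SIZE_M=2) for p in range(8)])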

View File

@@ -14,10 +14,15 @@
# limitations under the License. # limitations under the License.
""" """
from typing import Callable, Dict, Optional
import numpy as np
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
from fastdeploy.config import FDConfig
from .utils import get_tensor from .utils import get_tensor
@@ -28,16 +33,16 @@ class RMSNorm(nn.Layer):
def __init__(
self,
- fd_config,
+ fd_config: FDConfig,
- hidden_size,
+ hidden_size: int,
- eps=1e-5,
+ eps: float = 1e-5,
- prefix="",
+ prefix: str = "",
- linear_bias=None,
+ linear_bias: paddle.Tensor = None,
- quant_scale=None,
+ quant_scale: float = None,
- begin_norm_axis=1,
+ begin_norm_axis: int = 1,
- ):
+ ) -> None:
"""
- Initializes the normalization layer.
+ Initializes the RMSNormalization layer.
Args:
fd_config (FDConfig): Arguments related to inference, containing
@@ -45,33 +50,33 @@ class RMSNorm(nn.Layer):
num_attention_heads, and ffn_hidden_size.
hidden_size (int) : size of hidden state.
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
- weight_key (str): Key name of weight in the pdparams state dict. Defaults to None, means no weight.
+ prefix(str,optional):The name of current layer. Defaults to "".
- bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias.
+ linear_bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
- linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
+ quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
+ begin_norm_axis (int, optional): The axis along which to perform normalization. Defaults to 1.
Raises:
NotImplementedError: If the specified norm_type is not supported.
"""
super().__init__()
self.fd_config = fd_config
- self.prefix = prefix
+ self.prefix: str = prefix
- self.hidden_size = hidden_size
+ self.hidden_size: int = hidden_size
if len(prefix) == 0:
- self.weight_key = None
+ self.weight_key: Optional[str] = None
else:
- self.weight_key = f"{prefix}.weight"
+ self.weight_key: Optional[str] = f"{prefix}.weight"
- self.with_weight = self.weight_key is not None
+ self.with_weight: bool = self.weight_key is not None
- self.eps = eps
+ self.eps: float = eps
- self.norm_func = fused_rms_norm
+ self.norm_func: Callable = fused_rms_norm
- self.linear_bias = linear_bias
+ self.linear_bias: Optional[paddle.Tensor] = linear_bias
- self.quant_scale = quant_scale
+ self.quant_scale: Optional[float] = quant_scale
- self._dtype = self._helper.get_default_dtype()
+ self._dtype: str = self._helper.get_default_dtype()
- self._norm_weight_dtype = self._dtype
+ self._norm_weight_dtype: str = self._dtype
- self.begin_norm_axis = begin_norm_axis
+ self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
- self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
+ self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
- self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
+ self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
- self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
+ self.begin_norm_axis: int = begin_norm_axis
- self.begin_norm_axis = begin_norm_axis
self.init_weight()
@@ -88,7 +93,8 @@ class RMSNorm(nn.Layer):
dtype=self._norm_weight_dtype,
)
- def load_state_dict(self, state_dict):
+ def load_state_dict(self, state_dict: Dict[str,
+ paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.
@@ -102,7 +108,10 @@ class RMSNorm(nn.Layer):
self._norm_weight_dtype)
self.ln_weight.set_value(weight_tensor)
- def forward(self, x, residual_input=None):
+ def forward(
+ self,
+ x,
+ residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
"""
Defines the forward computation of the layer.
@@ -140,18 +149,18 @@ class RMSNorm(nn.Layer):
class LayerNorm(nn.Layer):
"""
- Normalization layer.
+ Initializes the LayerNormalization layer
"""
def __init__(
self,
- fd_config,
+ fd_config: FDConfig,
- hidden_size,
+ hidden_size: int,
- eps=1e-5,
+ eps: float = 1e-5,
prefix="",
- linear_bias=None,
+ linear_bias: paddle.Tensor = None,
- quant_scale=None,
+ quant_scale: float = None,
- with_bias=False,
+ with_bias: bool = False,
):
"""
Initializes the normalization layer.
@@ -160,35 +169,37 @@ class LayerNorm(nn.Layer):
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
- prefix (str): Unique name of the layer, used for naming internal attributes,
- you can give it any name you like.
hidden_size (int) : size of hidden state.
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
+ prefix (str): Unique name of the layer, used for naming internal attributes,
+ you can give it any name you like.
linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
+ quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
+ with_bias (bool):Whether to include bias or not. Defaults to False.
Raises:
NotImplementedError: If the specified norm_type is not supported.
"""
super().__init__()
self.fd_config = fd_config
- self.prefix = prefix
+ self.prefix: str = prefix
- self.hidden_size = hidden_size
+ self.hidden_size: int = hidden_size
if len(prefix) == 0:
- self.weight_key = None
+ self.weight_key: Optional[str] = None
else:
- self.weight_key = f"{prefix}.weight"
+ self.weight_key: Optional[str] = f"{prefix}.weight"
- self.with_weight = self.weight_key is not None
+ self.with_weight: bool = self.weight_key is not None
- self.bias_key = f"{prefix}.bias"
+ self.bias_key: str = f"{prefix}.bias"
- self.with_bias = with_bias
+ self.with_bias: bool = with_bias
- self.eps = eps
+ self.eps: float = eps
+ self.quant_scale: float = quant_scale
+ self.norm_func: Callable = fused_layer_norm
+ self.linear_bias: Optional[paddle.Tensor] = linear_bias
+ self._dtype: str = self._helper.get_default_dtype()
+ self._norm_weight_dtype: str = "float32"
- self.norm_func = fused_layer_norm
+ self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
- self.linear_bias = linear_bias
+ self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
- self._dtype = self._helper.get_default_dtype()
+ self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
- self._norm_weight_dtype = "float32"
- self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
- self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
- self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
self.init_weight()
@@ -212,7 +223,8 @@ class LayerNorm(nn.Layer):
dtype=self._norm_weight_dtype,
)
- def load_state_dict(self, state_dict):
+ def load_state_dict(self, state_dict: Dict[str,
+ paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.
@@ -233,7 +245,10 @@ class LayerNorm(nn.Layer):
self._norm_weight_dtype)
self.ln_bias.set_value(bias_tensor)
- def forward(self, x, residual_input=None):
+ def forward(
+ self,
+ x,
+ residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
"""
Defines the forward computation of the layer.
@@ -259,7 +274,7 @@ class LayerNorm(nn.Layer):
begin_norm_axis=1,
bias=self.linear_bias,
residual=residual_input,
- quant_scale=-1,
+ quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,

View File

@@ -132,18 +132,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer):
+ # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
+ linear_weight_scale_shape = [layer.linear_weight_shape[1]]
layer.linear_weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
layer.weight_dtype = "int8"
- linear_weight_scale_shape = [layer.embed_dim]
- if hasattr(layer, "linear_weight_shape"):
- if isinstance(layer.linear_weight_shape, list):
- layer_weight_shape = layer.linear_weight_shape
- linear_weight_scale_shape = layer_weight_shape[:1]
- if self.quant_config.name() == "wint4":
- linear_weight_scale_shape[0] *= 2
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
dtype=layer._dtype,
@@ -195,6 +191,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
weight_scale.astype(paddle.get_default_dtype()))
def process_loaded_weights(self, layer, weight) -> None:
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
+ weight,
algo=self.quant_config.algo,

View File

@@ -14,13 +14,18 @@
# limitations under the License.
"""
- from typing import Optional
+ import math
+ from typing import Optional, Tuple
import paddle
+ import paddle.nn as nn
from fastdeploy.config import ModelConfig
from fastdeploy.platforms import current_platform
+ if current_platform.is_cuda():
+ from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding
from .utils import CpuGuard
@@ -99,20 +104,164 @@ class QwenRotaryEmbedding:
return rot_emb
def yarn_get_mscale(scale=1, mscale=1):
"""
"""
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_find_correction_dim(num_rotations,
dim,
base=10000,
max_position_embeddings=2048):
"""
"""
return (dim * math.log(max_position_embeddings /
(num_rotations * 2 * math.pi))) / (2 *
math.log(base))
def yarn_find_correction_range(low_rot,
high_rot,
dim,
base=10000,
max_position_embeddings=2048):
"""
"""
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def yarn_linear_ramp_mask(min, max, dim):
"""
"""
if min == max:
max += 0.001 # Prevent singularity
linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max -
min)
ramp_func = paddle.clip(linear_func, 0, 1)
return ramp_func
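# Illustrative sketch (not part of the original file) of how the YaRN helpers
# above are typically combined; the numbers are made-up examples, not model defaults.
def _yarn_helpers_example():
    rotary_dim, base, orig_len, factor = 64, 10000, 4096, 4.0
    # Magnitude correction: 0.1 * 1.0 * ln(4) + 1 ~= 1.139
    mscale = yarn_get_mscale(scale=factor, mscale=1.0)
    # Dimensions below `low` are extrapolated, dimensions above `high` are
    # interpolated; the ramp blends the two in between.
    low, high = yarn_find_correction_range(32, 1, rotary_dim, base, orig_len)
    ramp = yarn_linear_ramp_mask(low, high, rotary_dim // 2)
    return mscale, low, high, ramp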
class DeepseekScalingRotaryEmbedding(nn.Layer):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
Args:
rotary_dim(int): Dimension of rotary embeddings (head dimension)
max_position_embeddings(int): Original training context length
base(float): Base value used to compute the inverse frequencies.
scaling_factor(float): Context extension scaling ratio (target_len / original_len)
extrapolation_factor(float): Weight for extrapolated frequencies (default=1)
attn_factor(float): Attention magnitude scaling factor (default=1)
beta_fast(int): High-frequency correction cutoff (default=32)
beta_slow(int): Low-frequency correction cutoff (default=1)
mscale(float): Primary magnitude scaling factor (default=1)
mscale_all_dim(float): Alternate magnitude scaling factor (default=0)
"""
def __init__(
self,
rotary_dim: int,
max_position_embeddings: int,
base: int,
scaling_factor: float,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
super().__init__()
self._dtype = paddle.get_default_dtype()
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
cache = self._compute_cos_sin_cache()
self.cos_sin_cache: paddle.Tensor
self.register_buffer("cos_sin_cache", cache, persistable=True)
def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
pos_freqs = self.base**(
paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _compute_cos_sin_cache(self) -> paddle.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = paddle.arange(self.max_position_embeddings * self.scaling_factor,
dtype=paddle.float32)
freqs = paddle.einsum("i,j->ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale
cache = paddle.concat((cos, sin), axis=-1)
return cache.cast(self._dtype)
def forward(
self,
position_ids: paddle.Tensor,
query: paddle.Tensor,
key: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""
"""
# In-place operations that update the query and key tensors.
fused_rotary_position_encoding(query, key, position_ids,
self.cos_sin_cache, self.rotary_dim,
False)
return query, key
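# Illustrative usage sketch (not part of the original file). Shapes and scaling
# values are assumptions for demonstration; the fused rotary op requires CUDA.
def _deepseek_rope_example():
    rope = DeepseekScalingRotaryEmbedding(
        rotary_dim=64,
        max_position_embeddings=4096,
        base=10000,
        scaling_factor=40.0,
        mscale=1.0,
        mscale_all_dim=1.0,
    )
    position_ids = paddle.arange(8, dtype="int64")
    query_pe = paddle.randn([8, 16, 64])  # [num_tokens, num_heads, rotary_dim]
    key_pe = paddle.randn([8, 1, 64])     # single shared key head, MLA style
    return rope(position_ids, query_pe, key_pe)  # rotated in place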
def get_rope_impl(
rotary_dim: int,
base: 10000.0,
- position_ids,
+ position_ids: paddle.Tensor,
model_config: Optional[ModelConfig] = None,
partial_rotary_factor=1,
- ):
+ ) -> paddle.Tensor:
"""
The real implementation of get_rope
"""
architecture = model_config.architectures[0]
- if model_config is not None and model_config is None or architecture.startswith(
- "Qwen"):
+ if model_config is None or architecture.startswith("Qwen"):
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base,
partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
@@ -126,10 +275,10 @@ def get_rope_impl(
def get_rope_xpu(
rotary_dim: int,
base: 10000.0,
- position_ids,
+ position_ids: paddle.Tensor,
- model_config: ModelConfig,
+ model_config: Optional[ModelConfig] = None,
partial_rotary_factor=1,
- ):
+ ) -> paddle.Tensor:
"""
In XPU, cos and sin compute must be done on cpu
"""
@@ -143,12 +292,27 @@ def get_rope_xpu(
def get_rope(
rotary_dim: int,
base: 10000.0,
- position_ids,
+ position_ids: paddle.Tensor,
- model_config: ModelConfig,
+ model_config: Optional[ModelConfig] = None,
- partial_rotary_factor=1,
+ partial_rotary_factor: int = 1,
- ):
+ ) -> paddle.Tensor:
"""
- The warpper of get_rope
+ Pre-calculate rotary position embedding for position_ids.
Args:
rotary_dim (int):
Dimension of rotary embeddings (head dimension)
base (float, optional):
Base value used to compute the inverse frequencies.
Default: 10000.0.
position_ids (paddle.Tensor):
Tensor containing position indices of input tokens.
model_config (Optional[ModelConfig]):
Model configuration object containing architecture information.
If provided, determines RoPE implementation based on model architecture.
partial_rotary_factor (int, optional):
Factor controlling partial rotary application.
Default: 1 (apply to all dimensions).
""" """
if current_platform.is_xpu(): if current_platform.is_xpu():
return get_rope_xpu(rotary_dim, base, position_ids, model_config, return get_rope_xpu(rotary_dim, base, position_ids, model_config,
@@ -255,7 +419,24 @@ def get_rope_3d(
paritial_rotary_factor: 1, paritial_rotary_factor: 1,
max_position: 131072, max_position: 131072,
freq_allocation: 2, freq_allocation: 2,
): ) -> paddle.Tensor:
"""
Pre-calculate rotary position embedding for position_ids.
Args:
rotary_dim (int):
Dimension of rotary embeddings (head dimension)
base (float, optional):
Base value used to compute the inverse frequencies.
Default: 10000.0.
position_ids (paddle.Tensor):
Tensor containing position indices of input tokens.
partial_rotary_factor (int, optional):
Factor controlling partial rotary application.
Default: 1 (apply to all dimensions).
max_position: Maximum position index to precompute.
freq_allocation: Number of rotary dimensions allocated to temporal axis
"""
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base,
paritial_rotary_factor,
max_position,

View File

@@ -0,0 +1,289 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import os
import paddle
import paddle.distributed as dist
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.model_utils import load_tp_checkpoint
from safetensors import safe_open
from tqdm import tqdm
from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.models.tp_utils import \
check_tensor_parallel_prerequisites
from fastdeploy.platforms import current_platform
def load_ep_checkpoint(model_path: str,
config: ModelConfig,
return_numpy: bool = False):
"""
Load an expert-parallel checkpoint slice: every non-expert weight plus only
the experts owned by this EP rank.
"""
with open(os.path.join(model_path, "model.safetensors.index.json"),
"r") as f:
weight_list = json.load(f)["weight_map"]
filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}
num_local_ffn_keys = []
for i in range(config.moe_layer_start_index, config.num_layers):
for j in range(
config.num_experts_start_offset,
config.num_experts_start_offset + config.num_experts_per_rank,
):
ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
ffn2_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
ffn2_quant_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight")
ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
ffn2_scale_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale")
num_local_ffn_keys.append(ffn1_key)
num_local_ffn_keys.append(ffn2_key)
num_local_ffn_keys.append(ffn1_quant_key)
num_local_ffn_keys.append(ffn2_quant_key)
num_local_ffn_keys.append(ffn1_scale_key)
num_local_ffn_keys.append(ffn2_scale_key)
for k in num_local_ffn_keys:
if k in weight_list:
filtered_map[k] = weight_list[k]
state_dict = {}
# Get all safetensor file paths that need to be opened
safetensor_paths = set(filtered_map.values())
# Open each safetensor file sequentially with progress bar
for safetensor_path in tqdm(safetensor_paths,
desc="Loading safetensor files",
unit="file"):
with safe_open(os.path.join(model_path, safetensor_path),
framework="np",
device="cpu") as f:
# Check if this file contains keys from filtered_map
for k in filtered_map:
if filtered_map[k] == safetensor_path and k in f.keys():
weight = f.get_tensor(k)
if not return_numpy:
weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(
paddle.framework._current_expected_place(), False)
state_dict[k] = weight
return state_dict
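# Illustrative call sketch (not part of the original file). The ModelConfig
# passed in must already carry moe_layer_start_index, num_experts_start_offset,
# num_experts_per_rank and num_layers for this EP rank.
def _load_ep_example(model_path, model_config):
    # Every non-expert weight plus only this rank's experts, as numpy arrays
    # because return_numpy=True.
    return load_ep_checkpoint(model_path, model_config, return_numpy=True)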
def safetensors_weights_iterator(safe_tensor_list: list[str], ):
"""
Yield (name, numpy tensor) pairs from a list of safetensors files.
"""
for st_file in tqdm(
safe_tensor_list,
desc="Loading safetensors checkpoint shards",
):
with safe_open(st_file, framework="np") as f:
for name in f.keys():
param = f.get_tensor(name)
yield name, param
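# Illustrative usage (not part of the original file): materialize a plain
# state dict from a list of *.safetensors file paths.
def _collect_state_dict(safetensor_files):
    return {
        name: tensor
        for name, tensor in safetensors_weights_iterator(safetensor_files)
    }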
def fastsafetensors_weights_iterator(safetensor_list: list[str], ):
"""
Return an iterator over tensors on GPU from a given safetensor_list.
"""
world_size = dist.get_world_size()
if world_size > 1:
pg = dist.get_group()
device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu"
else:
pg = SingleGroup()
device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda(
) else "cpu"
safetensor_files_sub_lists = [
safetensor_list[i:i + world_size]
for i in range(0, len(safetensor_list), world_size)
]
for st_file in tqdm(
safetensor_files_sub_lists,
desc="Loading fastsafetensors checkpoint shards",
):
loader = SafeTensorsFileLoader(pg,
device,
nogds=True,
debug_log=False,
framework="paddle")
rank_file_map = {i: [f] for i, f in enumerate(st_file)}
loader.add_filenames(rank_file_map)
try:
fb = loader.copy_files_to_device()
try:
keys = list(fb.key_to_rank_lidx.keys())
for k in keys:
t = fb.get_tensor(k)
yield k, t
finally:
fb.close()
finally:
loader.close()
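# Illustrative usage (not part of the original file): pair the iterator with
# get_all_safetensors (defined further below) to stream a whole model
# directory; tensors arrive already placed on the chosen device.
def _stream_model_dir(model_path):
    _, safetensor_files = get_all_safetensors(model_path)
    for name, tensor in fastsafetensors_weights_iterator(safetensor_files):
        yield name, tensor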
def load_pre_sharded_checkpoint(model_path: str,
local_rank: int,
use_fastsafetensor: bool = False):
"""
Load a checkpoint that was pre-sharded into per-rank directories (rank{local_rank}/).
"""
state_dict = {}
_, safetensor_files = get_all_safetensors(
os.path.join(model_path, f"rank{local_rank}"))
weights_iterator = safetensors_weights_iterator(safetensor_files)
for name, weight in weights_iterator:
state_dict[name] = weight
return state_dict
def get_all_safetensors(model_path: str):
"""
Return (weight key names, sorted safetensors file paths) for a model directory.
"""
safe_model_path = os.path.join(model_path, "model.safetensors")
if os.path.exists(safe_model_path):
safetensor_list = [safe_model_path]
with safe_open(safe_model_path, framework="np", device="cpu") as f:
key_name_list = f.keys()
return key_name_list, safetensor_list
else:
with open(os.path.join(model_path, "model.safetensors.index.json"),
"r") as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(
os.path.join(model_path, weight_map[weight_name]))
key_name_list = list(set(weight_map.keys()))
safetensor_list = list(weight_files_in_index)
safetensor_list.sort()
return key_name_list, safetensor_list
def load_tp_checkpoint_v1(
model_path: str,
cls: PretrainedModel,
fd_config: FDConfig,
use_fastsafetensor: bool = True,
):
"""
Stream safetensors weights and apply tensor-parallel split actions as each
weight is loaded.
"""
safetensor_keys, safetensor_files = get_all_safetensors(model_path)
if use_fastsafetensor:
weights_iterator = fastsafetensors_weights_iterator(safetensor_files)
else:
weights_iterator = safetensors_weights_iterator(safetensor_files)
tensor_parallel_filtered_map = {}
check_tensor_parallel_prerequisites(
fd_config,
cls,
tensor_parallel_filtered_map,
safetensor_keys,
)
need_tp = True if tensor_parallel_filtered_map else False
state_dict = {}
for key, weight in weights_iterator:
paddle.device.cuda.synchronize()
if need_tp and key in tensor_parallel_filtered_map:
action = tensor_parallel_filtered_map.pop(key)
tensor = action(weight).clone()
else:
tensor = weight.clone()
state_dict[key] = tensor
weight.value().get_tensor()._clear()
return state_dict
def deal_state_dict(state_dict):
"""deal_state_dict"""
device = paddle.CUDAPinnedPlace()
for name, src in state_dict.items():
if src._is_initialized() and not isinstance(src.place,
paddle.CUDAPinnedPlace):
dst = src._copy_to(device, True)
dst_tensor = dst.value().get_tensor()
src_tensor = src.value().get_tensor()
src_tensor._clear()
src_tensor._share_data_with(dst_tensor)
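# Note (added for clarity, not in the original file): deal_state_dict copies
# each initialized weight into CUDA pinned host memory, clears the original
# storage (typically on GPU), and shares the pinned buffer back with the source
# tensor, so the state dict is held off-device but can be transferred back quickly.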
def load_composite_checkpoint(
model_path: str,
cls: PretrainedModel,
fd_config: FDConfig,
return_numpy=True,
):
"""
# This method supports loading model weights under three parallelism strategies:
# 1. Expert Parallel (EP)
# 2. Tensor Parallel (TP)
# 3. Pre-sharded (pre-split)
"""
if fd_config.parallel_config.use_ep:
state_dict = load_ep_checkpoint(model_path,
fd_config.model_config,
return_numpy=True)
else:
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank")
and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if fd_config.parallel_config.tensor_parallel_degree != len(
rank_dirs):
raise ValueError(
f"Your model only supports loading with tp{len(rank_dirs)}"
)
state_dict = load_pre_sharded_checkpoint(
model_path,
fd_config.parallel_config.tensor_parallel_rank,
use_fastsafetensor=False,
)
else:
if fd_config.load_config.use_fastsafetensor and (
current_platform.available()
and current_platform.is_cuda()):
state_dict = load_tp_checkpoint_v1(model_path,
cls,
fd_config,
use_fastsafetensor=True)
deal_state_dict(state_dict)
else:
state_dict = load_tp_checkpoint(model_path,
cls,
fd_config.model_config,
return_numpy=return_numpy)
if not state_dict:
raise ValueError("weight not found in state_dict !")
return state_dict
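# Illustrative dispatch sketch (not part of the original file): which branch
# load_composite_checkpoint takes for a given deployment. The real function
# additionally validates the number of rank directories against the TP degree.
def _loading_strategy(fd_config, model_path):
    if fd_config.parallel_config.use_ep:
        return "expert-parallel slice (load_ep_checkpoint)"
    rank_dirs = [d for d in os.listdir(model_path) if d.startswith("rank")]
    if len(rank_dirs) > 1:
        return "pre-sharded per-rank weights (load_pre_sharded_checkpoint)"
    return "tensor-parallel split at load time (load_tp_checkpoint / _v1)"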

View File

@@ -20,6 +20,10 @@ import paddle
from paddle import nn
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
+ from fastdeploy.model_executor.load_weight_utils import \
+ load_composite_checkpoint
+ from fastdeploy.model_executor.models.deepseek_v3 import \
+ DeepSeekV3PretrainedModel
from fastdeploy.model_executor.models.ernie4_5_moe import \
Ernie4_5_PretrainedModel
from fastdeploy.model_executor.models.ernie4_5_mtp import \
@@ -28,7 +32,7 @@ from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel
- from fastdeploy.model_executor.models.utils import load_checkpoint
+ from fastdeploy.platforms import current_platform
MODEL_CLASSES = {
"Ernie4_5_MoeForCausalLM": Ernie4_5_PretrainedModel,
@@ -36,7 +40,8 @@ MODEL_CLASSES = {
"Qwen2ForCausalLM": Qwen2PretrainedModel,
"Qwen3ForCausalLM": Qwen3PretrainedModel,
"Qwen3MoeForCausalLM": Qwen3MoePretrainedModel,
- "Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel
+ "Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel,
+ "DeepseekV3ForCausalLM": DeepSeekV3PretrainedModel,
}
@@ -73,23 +78,38 @@ class DefaultModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
pass
+ def clean_memory_fragments(self, state_dict: dict) -> None:
+ """clean_memory_fragments"""
+ if current_platform.is_cuda():
+ if state_dict:
+ for k, v in state_dict.items():
+ if isinstance(v, paddle.Tensor):
+ v.value().get_tensor()._clear()
+ paddle.device.cuda.empty_cache()
+ paddle.device.cuda.synchronize()
def load_model(self, fd_config: FDConfig) -> nn.Layer:
context = paddle.LazyGuard()
architectures = fd_config.model_config.architectures[0]
# TODO(gongshaotian): Now, only support safetensor
model_class = MODEL_CLASSES[architectures]
- state_dict = load_checkpoint(
- fd_config.parallel_config.model_name_or_path,
- model_class,
- fd_config.model_config,
- return_numpy=True)
with context:
model_cls = ModelRegistry.get_class(architectures)
model = model_cls(fd_config)
model.eval()
- model.set_state_dict(state_dict)
+ # RL model not need set_state_dict
+ if fd_config.load_config.dynamic_load_weight:
+ return model
+ state_dict = load_composite_checkpoint(
+ fd_config.parallel_config.model_name_or_path,
+ model_class,
+ fd_config,
+ return_numpy=True,
+ )
+ model.set_state_dict(state_dict)
+ self.clean_memory_fragments(state_dict)
return model

View File

@@ -20,15 +20,6 @@ from pathlib import Path
from .model_base import ModelForCasualLM, ModelRegistry
- inference_runner_supported_models = [
- "Ernie4_5_MoeForCausalLM",
- "Ernie4_5_MTPForCausalLM",
- "Qwen2ForCausalLM",
- "Qwen3MoeForCausalLM",
- "Ernie4_5_ForCausalLM",
- "Qwen3ForCausalLM",
- ]
def _find_py_files(root_dir):
root_path = Path(root_dir)
@@ -44,22 +35,23 @@ def _find_py_files(root_dir):
return py_files
- def auto_models_registry():
+ def auto_models_registry(dir_path,
+ register_path="fastdeploy.model_executor.models",
+ suffix=""):
"""
auto registry all models in this folder
"""
- for module_file in _find_py_files(os.path.dirname(__file__)):
+ for module_file in _find_py_files(dir_path):
try:
- module = importlib.import_module(
- f'fastdeploy.model_executor.models.{module_file}')
+ module = importlib.import_module(f'{register_path}.{module_file}')
for attr_name in dir(module):
attr = getattr(module, attr_name)
if inspect.isclass(attr) and issubclass(
attr,
ModelForCasualLM) and attr is not ModelForCasualLM:
- ModelRegistry.register(attr)
+ ModelRegistry.register(attr, suffix=suffix)
except ImportError:
raise ImportError(f"{module_file=} import error")
- auto_models_registry()
+ auto_models_registry(os.path.dirname(__file__))
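# Example (illustrative; the plugin package and paths below are hypothetical):
# an out-of-tree package can reuse the same hook to register its own
# ModelForCasualLM subclasses, optionally tagging them with a name suffix:
#
#     auto_models_registry("/path/to/my_plugin/models",
#                          register_path="my_plugin.models",
#                          suffix="_rl")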

View File

@@ -0,0 +1,763 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import math
from functools import partial
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear,
ReplicatedLinear, RowParallelLinear)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.layers.rotary_embedding import \
DeepseekScalingRotaryEmbedding
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.platforms import current_platform
from fastdeploy.worker.forward_meta import ForwardMeta
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import \
get_position_ids_and_mask_encoder_batch
class DeepSeekV3MLP(nn.Layer):
"""
DeepSeekV3MLP, for Dense FFN and Shared Experts Layer.
"""
def __init__(
self,
fd_config: FDConfig,
intermediate_size: int,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.up_gate_proj",
input_size=fd_config.model_config.hidden_size,
output_size=intermediate_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)
self.down_proj = RowParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.down_proj",
input_size=intermediate_size,
output_size=fd_config.model_config.hidden_size,
with_bias=False,
reduce_results=reduce_results,
)
self.act_fn = SiluAndMul(
fd_config=fd_config,
bias=None,
act_method=fd_config.model_config.hidden_act,
)
def load_state_dict(self, state_dict):
"""
"""
self.gate_up_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict)
def forward(self, x):
"""
"""
gate_up_out = self.gate_up_proj(x)
act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out)
return down_out
class DeepSeekV3MoE(nn.Layer):
"""
DeepSeekV3MoE, for MoE Layer.
"""
def __init__(self, fd_config: FDConfig, layer_id: int,
prefix: str) -> None:
super().__init__()
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key":
f"{prefix}.gate.e_score_correction_bias",
"ffn1_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight",
"ffn2_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.weight",
}
self.fused_moe = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.model_config.deepseekv3.
moe_intermediate_size,
num_experts=fd_config.model_config.deepseekv3.n_routed_experts,
top_k=fd_config.model_config.deepseekv3.num_experts_per_tok,
topk_method=fd_config.model_config.deepseekv3.topk_method,
topk_group=fd_config.model_config.deepseekv3.topk_group,
n_group=fd_config.model_config.deepseekv3.n_group,
routed_scaling_factor=fd_config.model_config.deepseekv3.
routed_scaling_factor,
layer_idx=layer_id,
weight_key_map=weight_key_map,
)
self.num_shared_experts = fd_config.model_config.deepseekv3.n_shared_experts
shared_experts_intermediate_size = (
self.num_shared_experts *
fd_config.model_config.deepseekv3.moe_intermediate_size)
self.shared_experts = DeepSeekV3MLP(
fd_config=fd_config,
intermediate_size=shared_experts_intermediate_size,
prefix=f"{prefix}.shared_experts",
reduce_results=False,
)
def load_state_dict(self, state_dict):
"""
"""
self.fused_moe.load_state_dict(state_dict)
self.shared_experts.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor):
"""
"""
shared_experts_out = self.shared_experts(hidden_states)
moe_out = self.fused_moe(hidden_states)
moe_out = moe_out + shared_experts_out
# Do the TP all-reduce once, after summing routed and shared expert outputs.
if self.tp_size > 1:
tensor_model_parallel_all_reduce(moe_out)
return moe_out
class DeepseekV3MLAAttention(nn.Layer):
"""
DeepseekV3MLAAttention
"""
def __init__(self,
fd_config: FDConfig,
layer_id: int,
prefix: str = "") -> None:
super().__init__()
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
self.hidden_size = fd_config.model_config.hidden_size
self.num_attention_heads = fd_config.model_config.num_attention_heads
self.num_attention_heads_tp = self.num_attention_heads // self.tp_size
# MLA
self.qk_nope_head_dim = fd_config.model_config.deepseekv3.qk_nope_head_dim
self.qk_rope_head_dim = fd_config.model_config.deepseekv3.qk_rope_head_dim
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
self.v_head_dim = fd_config.model_config.deepseekv3.v_head_dim
self.q_lora_rank = fd_config.model_config.deepseekv3.q_lora_rank
self.kv_lora_rank = fd_config.model_config.deepseekv3.kv_lora_rank
self.attn_softmax_scale = self.qk_head_dim**-0.5
self.rope_theta = fd_config.model_config.rope_theta
self.rms_norm_eps = fd_config.model_config.rms_norm_eps
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(fd_config=fd_config,
prefix=f"{prefix}.q_a_proj",
input_size=self.hidden_size,
output_size=self.q_lora_rank,
with_bias=False)
self.q_a_layernorm = RMSNorm(fd_config,
hidden_size=self.q_lora_rank,
eps=self.rms_norm_eps,
prefix=f"{prefix}.q_a_layernorm")
self.q_b_proj = ColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.q_b_proj",
input_size=self.q_lora_rank,
output_size=self.num_attention_heads * self.qk_head_dim,
with_bias=False,
)
else:
assert (self.q_lora_rank is not None
), "self.q_lora_rank is None, Please Check your config."
# Not split across TP ranks; runs as a W4A16 GEMM.
self.kv_a_proj_with_mqa = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
input_size=self.hidden_size,
output_size=self.kv_lora_rank + self.qk_rope_head_dim,
with_bias=False)
self.kv_a_layernorm = RMSNorm(fd_config,
hidden_size=self.kv_lora_rank,
eps=self.rms_norm_eps,
prefix=f"{prefix}.kv_a_layernorm")
self.kv_b_proj = ColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.kv_b_proj",
input_size=self.kv_lora_rank,
output_size=self.num_attention_heads *
(self.qk_nope_head_dim + self.v_head_dim),
with_bias=False,
)
self.o_proj = RowParallelLinear(fd_config,
prefix=f"{prefix}.o_proj",
input_size=self.num_attention_heads *
self.v_head_dim,
output_size=self.hidden_size,
with_bias=False)
self.kv_b_proj_bmm = KVBatchLinear(
fd_config=fd_config,
prefix=f"{prefix}.kv_b_proj",
kv_lora_rank=self.kv_lora_rank,
num_attention_heads=self.num_attention_heads,
qk_nope_head_dim=self.qk_nope_head_dim,
v_head_dim=self.v_head_dim)
self.rope_scaling = fd_config.model_config.deepseekv3.rope_scaling
if self.rope_scaling:
mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
scaling_factor = self.rope_scaling["factor"]
mscale = self.yarn_get_mscale(scaling_factor,
float(mscale_all_dim))
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
rope_scaling_kwargs = {
key: self.rope_scaling[key]
for key in [
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
] if key in self.rope_scaling
}
self.rope_scaling_factor = self.rope_scaling["factor"]
self.rope_scaling_original_max_position_embeddings = self.rope_scaling[
"original_max_position_embeddings"]
self.rotary_emb = DeepseekScalingRotaryEmbedding(
self.qk_rope_head_dim,
max_position_embeddings=self.
rope_scaling_original_max_position_embeddings,
base=self.rope_theta,
scaling_factor=self.rope_scaling_factor,
**rope_scaling_kwargs,
)
self.mla_attn = Attention(
fd_config=fd_config,
layer_id=layer_id,
prefix=prefix,
use_neox_rotary_style=False,
)
self.prefix = prefix
@staticmethod
def yarn_get_mscale(scale=1, mscale=1):
"""
"""
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
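# Notes (added for clarity, not in the original file):
# - attn_softmax_scale starts at qk_head_dim**-0.5 and, when YaRN rope_scaling
#   is configured, is multiplied by mscale**2, where
#   mscale = 0.1 * mscale_all_dim * ln(scaling_factor) + 1.
# - forward() below takes two paths: prefill tokens run full multi-head
#   attention over K/V reconstructed by kv_b_proj, while decode tokens stay in
#   the compressed kv_lora_rank space and absorb kv_b_proj into the query and
#   output projections through kv_b_proj_bmm.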
def forward(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
"""
"""
layernorm_out = hidden_states
fmha_out = paddle.zeros(shape=[
layernorm_out.shape[0],
self.num_attention_heads_tp * self.v_head_dim
],
dtype=layernorm_out.dtype)
decode_stage = forward_meta.is_decode_batch
prefill_stage = not (forward_meta.is_decode_batch)
if prefill_stage:
query = self.q_a_proj(layernorm_out)
query = self.q_a_layernorm(query)
query = self.q_b_proj(query)
query = query.reshape(
[-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
compressed_kv, key_pe = compressed_kv.split(
[self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
key_value = self.kv_b_proj(compressed_kv)
key_value = key_value.reshape([
-1, self.num_attention_heads_tp,
self.qk_nope_head_dim + self.v_head_dim
])
key_nope, value = key_value.split(
[self.qk_nope_head_dim, self.v_head_dim], axis=-1)
query[..., self.qk_nope_head_dim:] = query_pe
key = paddle.empty_like(query)
key[..., :self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim:] = key_pe
value = paddle.nn.functional.pad(
value, [0, self.qk_head_dim - self.v_head_dim], value=0)
fmha_out_prefill = self.mla_attn(q=query,
k=key,
v=value,
qkv=None,
compressed_kv=compressed_kv,
k_pe=key_pe,
forward_meta=forward_meta)
fmha_out_prefill = fmha_out_prefill.reshape(
[-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, :self.v_head_dim]
fmha_out_prefill = fmha_out_prefill.reshape(
[-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(
fmha_out_prefill.dtype)
fmha_out = fmha_out + fmha_out_prefill
if decode_stage:
query = self.q_a_proj(layernorm_out)
query = self.q_a_layernorm(query)
ln_out_or_q_c = query
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
compressed_kv, key_pe = compressed_kv.split(
[self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)
query = self.q_b_proj(ln_out_or_q_c)
query = query.reshape(
[-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]),
proj_type='k').transpose([1, 0, 2])
q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
q_input = q_input.reshape([
-1,
self.num_attention_heads_tp *
(self.kv_lora_rank + self.qk_rope_head_dim),
])
fmha_out_decode = self.mla_attn(q=q_input,
k=None,
v=None,
qkv=None,
compressed_kv=compressed_kv,
k_pe=key_pe,
forward_meta=forward_meta)
fmha_out_decode = fmha_out_decode.reshape(
[-1, self.num_attention_heads_tp,
self.kv_lora_rank]).transpose([1, 0, 2])
fmha_out_decode = (self.kv_b_proj_bmm(
fmha_out_decode, proj_type='v').transpose([1, 0, 2]).reshape(
[-1, self.num_attention_heads_tp * self.v_head_dim]))
fmha_out = fmha_out + fmha_out_decode
output = self.o_proj(fmha_out)
return output
def load_state_dict(self, state_dict):
"""
"""
self.q_a_proj.load_state_dict(state_dict)
self.q_a_layernorm.load_state_dict(state_dict)
self.kv_a_proj_with_mqa.load_state_dict(state_dict)
self.kv_a_layernorm.load_state_dict(state_dict)
self.q_b_proj.load_state_dict(state_dict)
self.kv_b_proj_bmm.load_state_dict(state_dict)
self.kv_b_proj.load_state_dict(state_dict)
# NOTE(Ryan): Make sure kv_b_proj_bmm is loaded before kv_b_proj;
# the shared weight key is popped from state_dict once kv_b_proj has loaded it.
self.o_proj.load_state_dict(state_dict)
class DeepSeekV3DecoderLayer(nn.Layer):
"""
DeepSeekV3DecoderLayer
"""
def __init__(
self,
fd_config: FDConfig,
prefix: str = "",
) -> None:
super().__init__()
layer_id = int(prefix.split(sep='.')[-1])
self.self_attn = DeepseekV3MLAAttention(
fd_config=fd_config,
layer_id=layer_id,
prefix=f"{prefix}.self_attn",
)
if (fd_config.model_config.deepseekv3.n_routed_experts is not None
and layer_id
>= fd_config.model_config.deepseekv3.first_k_dense_replace):
self.mlp = DeepSeekV3MoE(
fd_config=fd_config,
layer_id=layer_id,
prefix=f"{prefix}.mlp",
)
else:
self.mlp = DeepSeekV3MLP(
fd_config=fd_config,
intermediate_size=fd_config.model_config.intermediate_size,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
)
self.post_attention_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
)
def load_state_dict(self, state_dict):
"""
"""
self.self_attn.load_state_dict(state_dict)
self.mlp.load_state_dict(state_dict)
self.input_layernorm.load_state_dict(state_dict)
self.post_attention_layernorm.load_state_dict(state_dict)
def forward(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
residual: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
"""
"""
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(forward_meta, hidden_states,
position_ids, mask_encoder_batch)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class DeepSeekV3Model(nn.Layer):
"""
DeepSeekV3Model
"""
def __init__(
self,
fd_config: FDConfig = None,
):
"""
Initializer for the DeepSeekV3Model class.
"""
super().__init__()
self.num_layers = fd_config.model_config.num_layers
fd_config.model_config.prefix_name = "deepseek_v3"
self.embeddings = VocabParallelEmbedding(
fd_config,
num_embeddings=fd_config.model_config.vocab_size,
embedding_dim=fd_config.model_config.hidden_size,
params_dtype=paddle.get_default_dtype(),
prefix="deepseek_v3.embed_tokens",
)
self.decoder_layers = nn.LayerList([
DeepSeekV3DecoderLayer(
fd_config,
prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
for i in range(self.num_layers)
])
self.norm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix="deepseek_v3.norm",
)
def pre_process(self, forward_meta):
"""
"""
seq_lens_encoder = forward_meta.seq_lens_encoder
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
position_ids_shape = paddle.sum(seq_lens_this_time)
position_ids = paddle.empty(shape=position_ids_shape,
dtype=seq_lens_encoder.dtype)
mask_encoder_batch = paddle.empty(
shape=position_ids_shape,
dtype=seq_lens_encoder.dtype).unsqueeze(1)
get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
position_ids,
mask_encoder_batch)
return position_ids, mask_encoder_batch
def load_state_dict(self, state_dict):
"""
Load model parameters from a given state dictionary.
"""
self.embeddings.load_state_dict(state_dict)
self.norm.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
self.decoder_layers[i].load_state_dict(state_dict)
def forward(
self,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
"""
"""
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
position_ids, mask_encoder_batch = self.pre_process(forward_meta)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.decoder_layers[i](
forward_meta, hidden_states, residual, position_ids,
mask_encoder_batch)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
return out
class DeepseekV3ForCausalLM(ModelForCasualLM):
"""
DeepseekV3ForCausalLM
"""
def __init__(self, fd_config: FDConfig):
"""
Args:
fd_config (FDConfig): Configurations for the LLM model.
"""
super().__init__(fd_config)
self.model = DeepSeekV3Model(fd_config)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
self.lm_head = ParallelLMHead(
fd_config,
embedding_dim=fd_config.model_config.hidden_size,
num_embeddings=fd_config.model_config.vocab_size,
prefix="lm_head",
)
@classmethod
def name(cls):
"""
"""
return "DeepseekV3ForCausalLM"
@paddle.no_grad()
def set_state_dict(self, state_dict):
"""
Load model parameters from a given state dictionary.
"""
self.model.load_state_dict(state_dict)
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
"""
"""
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits[:, self.ori_vocab_size:] = -float("inf")
return logits
def forward(
self,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
"""
"""
hidden_states = self.model(ids_remove_padding, forward_meta)
return hidden_states
class DeepSeekV3PretrainedModel(PretrainedModel):
"""
DeepSeekV3PretrainedModel
"""
config_class = FDConfig
def _init_weight(self, layer):
"""
_init_weight
"""
return None
@classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True):
logger.info("DeepseekV3 inference model _get_tensor_parallel_mappings")
from paddleformers.transformers.conversion_utils import \
split_or_merge_func
fn = split_or_merge_func(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
)
def get_tensor_parallel_split_mappings(num_layers):
final_actions = {}
base_actions = {
"lm_head.weight": partial(fn, is_column=True),
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn,
is_column=False),
}
# Self Attention Layer which are need TP.
base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(
fn, is_column=True)
base_actions[
"layers.0.self_attn.q_b_proj.weight_scale_inv"] = partial(
fn, is_column=True)
base_actions[
"layers.0.self_attn.kv_b_proj.weight_scale_inv"] = partial(
fn, is_column=True)
# MLP Layer
base_actions["layers.0.mlp.gate_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.mlp.up_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.mlp.down_proj.weight"] = partial(
fn, is_column=False)
# Moe Layer
for expert_idx in range(config.n_routed_experts):
base_actions[
f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(
fn, is_column=True)
base_actions[
f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(
fn, is_column=True)
base_actions[
f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(
fn, is_column=False)
# Shared Expert Layer
base_actions[
"layers.0.mlp.shared_experts.up_proj.weight"] = partial(
fn, is_column=True)
base_actions[
"layers.0.mlp.shared_experts.gate_proj.weight"] = partial(
fn, is_column=True)
base_actions[
"layers.0.mlp.shared_experts.down_proj.weight"] = partial(
fn, is_column=False)
# MTP parts
base_actions["layers.61.embed_tokens.weight"] = partial(
fn, is_column=False)
base_actions["layers.61.eh_proj.weight"] = partial(fn,
is_column=True)
base_actions["layers.61.shared_head.head.weight"] = partial(
fn, is_column=True)
for key, action in base_actions.items():
if "layers.0." in key:
for i in range(num_layers):
final_actions[key.replace("layers.0.",
f"layers.{i}.")] = action
final_actions[key] = action
return final_actions
mappings = get_tensor_parallel_split_mappings(config.num_layers)
return mappings
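# Illustrative sketch (not part of the original file): applying the returned
# split actions to a full, unsplit checkpoint for one TP rank. The config must
# carry tensor_parallel_degree / tensor_parallel_rank / num_attention_heads /
# num_layers / n_routed_experts; key naming (model prefix) may need adjusting.
def _split_for_tp_rank(config, full_state_dict):
    actions = DeepSeekV3PretrainedModel._get_tensor_parallel_mappings(
        config, is_split=True)
    return {
        name: actions[name](weight) if name in actions else weight
        for name, weight in full_state_dict.items()
    }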

View File

@@ -37,291 +37,13 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+ from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
+ from fastdeploy.model_executor.models.utils import \
+ LayerIdPlaceholder as layerid
+ from fastdeploy.model_executor.models.utils import WeightMeta
from fastdeploy.worker.forward_meta import ForwardMeta
class Ernie4_5_PretrainedModel(PretrainedModel):
"""
Ernie4_5_PretrainedModel
"""
config_class = FDConfig
def _init_weight(self, layer):
"""
_init_weight
"""
return None
@classmethod
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
"""
get_tensor_parallel_mappings
"""
logger.info("erine inference model _get_tensor_parallel_mappings")
from paddleformers.transformers.conversion_utils import \
split_or_merge_func
fn = split_or_merge_func(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
)
def gqa_qkv_split_func(
weight,
tensor_parallel_degree,
tensor_parallel_rank,
num_attention_heads,
num_key_value_heads,
head_dim,
):
def get_shape(tensor):
return (tensor.get_shape()
if hasattr(tensor, "get_shape") else tensor.shape)
def slice_tensor(tensor, start, end):
shape = get_shape(tensor)
if len(shape) == 1:
return tensor[start:end]
else:
return tensor[..., start:end]
q_end = num_attention_heads * head_dim
k_end = q_end + num_key_value_heads * head_dim
v_end = k_end + num_key_value_heads * head_dim
q = slice_tensor(weight, 0, q_end)
k = slice_tensor(weight, q_end, k_end)
v = slice_tensor(weight, k_end, v_end)
def split_tensor(tensor, degree):
shape = get_shape(tensor)
size = shape[-1]
block_size = size // degree
if hasattr(tensor, "get_shape"):
return [
slice_tensor(tensor, i * block_size,
(i + 1) * block_size)
for i in range(degree)
]
else:
return np.split(tensor, degree, axis=-1)
q_list = split_tensor(q, tensor_parallel_degree)
k_list = split_tensor(k, tensor_parallel_degree)
v_list = split_tensor(v, tensor_parallel_degree)
if tensor_parallel_rank is None:
return [
np.concatenate([q_i, k_i, v_i], axis=-1)
for q_i, k_i, v_i in zip(q_list, k_list, v_list)
]
else:
return np.concatenate(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
],
axis=-1,
)
def gqa_qkv_merge_func(weight_list, num_attention_heads,
num_key_value_heads, head_dim):
tensor_parallel_degree = len(weight_list)
num_attention_heads = num_attention_heads // tensor_parallel_degree
num_key_value_heads = num_key_value_heads // tensor_parallel_degree
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
def get_shape(tensor):
return (tensor.get_shape()
if hasattr(tensor, "get_shape") else tensor.shape)
def slice_tensor(tensor, start, end):
if len(get_shape(tensor)) == 1:
return tensor[start:end]
else:
return tensor[..., start:end]
q_list, k_list, v_list = [], [], []
for weight in weight_list:
q_end = num_attention_heads * head_dim
k_end = q_end + num_key_value_heads * head_dim
v_end = k_end + num_key_value_heads * head_dim
q = slice_tensor(weight, 0, q_end)
k = slice_tensor(weight, q_end, k_end)
v = slice_tensor(weight, k_end, v_end)
q_list.append(q)
k_list.append(k)
v_list.append(v)
merged = q_list + k_list + v_list
if is_paddle_tensor:
tensor = paddle.concat(merged, axis=-1)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor
else:
return np.concatenate(merged, axis=-1)
if (config.num_key_value_heads is not None
and config.num_key_value_heads != config.num_attention_heads):
if is_split:
qkv_fn = partial(
gqa_qkv_split_func,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
head_dim=config.head_dim,
)
else:
qkv_fn = partial(
gqa_qkv_merge_func,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
head_dim=config.head_dim,
)
else:
qkv_fn = partial(fn, is_column=True)
def get_tensor_parallel_split_mappings(num_layers, moe_num_experts,
moe_num_shared_experts,
moe_layer_start_index):
final_actions = {}
base_model_prefix = "ernie"
base_actions = {
"lm_head.weight":
partial(fn, is_column=True),
# "eh_proj.weight": partial(fn, is_column=True),
f"{base_model_prefix}.embed_tokens.weight":
partial(fn, is_column=False),
}
base_actions[
f"{base_model_prefix}.layers.0.self_attn.qkv_proj.weight"] = qkv_fn
base_actions[
f"{base_model_prefix}.layers.0.self_attn.qkv_proj.quant_weight"] = qkv_fn
base_actions[
f"{base_model_prefix}.layers.0.self_attn.o_proj.weight"] = partial(
fn, is_column=False)
base_actions[
f"{base_model_prefix}.layers.0.self_attn.o_proj.quant_weight"] = partial(
fn, is_column=False)
base_actions[
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.quant_weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.0.mlp.down_proj.weight"] = (
partial(fn, is_column=False))
base_actions[
f"{base_model_prefix}.layers.0.mlp.down_proj.quant_weight"] = partial(
fn, is_column=False)
for expert_idx in range(moe_num_experts):
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.experts.{expert_idx}.up_gate_proj.weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.experts.{expert_idx}.up_gate_proj.quant_weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.experts.{expert_idx}.down_proj.weight"] = partial(
fn, is_column=False)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.experts.{expert_idx}.down_proj.quant_weight"] = partial(
fn, is_column=False)
if moe_num_shared_experts > 0:
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.shared_experts.up_gate_proj.weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.shared_experts.up_gate_proj.quant_weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.shared_experts.down_proj.weight"] = partial(
fn, is_column=False)
            base_actions[
                f"{base_model_prefix}.layers.{moe_layer_start_index}"
                f".mlp.shared_experts.down_proj.quant_weight"] = partial(
                    fn, is_column=False)
for key, action in base_actions.items():
if (f"{base_model_prefix}.layers.0.mlp.up_gate_proj.weight"
in key or
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.quant_weight"
in key
or f"{base_model_prefix}.layers.0.mlp.down_proj.weight"
in key or
f"{base_model_prefix}.layers.0.mlp.down_proj.quant_weight"
in key):
for i in range(moe_layer_start_index):
final_actions[key.replace("layers.0.",
f"layers.{i}.")] = action
elif f"layers.{moe_layer_start_index}.mlp.experts." in key:
for i in range(moe_layer_start_index, num_layers):
final_actions[key.replace(
f"layers.{moe_layer_start_index}.",
f"layers.{i}.")] = action
elif f"layers.{moe_layer_start_index}.mlp.shared_experts." in key:
for i in range(moe_layer_start_index, num_layers):
final_actions[key.replace(
f"layers.{moe_layer_start_index}.",
f"layers.{i}.")] = action
elif f"{base_model_prefix}.layers.0." in key:
for i in range(num_layers):
final_actions[key.replace("layers.0.",
f"layers.{i}.")] = action
final_actions[key] = action
return final_actions
moe_num_experts = 0
moe_num_shared_experts = 0
if isinstance(config.moe_num_experts, list):
moe_num_experts = sum(config.moe_num_experts)
elif isinstance(config.moe_num_experts, int):
moe_num_experts = config.moe_num_experts
if hasattr(config, 'moe_num_shared_experts'):
moe_num_shared_experts = config.moe_num_shared_experts
moe_layer_start_index = -1
if isinstance(config.moe_layer_start_index, list):
moe_layer_start_index = min(config.moe_layer_start_index)
elif isinstance(config.moe_layer_start_index, int):
moe_layer_start_index = config.moe_layer_start_index
mappings = get_tensor_parallel_split_mappings(
config.num_layers,
moe_num_experts,
moe_num_shared_experts,
moe_layer_start_index,
)
return mappings
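# A minimal usage sketch (hypothetical driver code, not one of the PR's files):
# the dict returned above maps weight names to partial split/merge actions, so a
# loading loop built with is_split=True keeps only this rank's shard per weight.
def shard_state_dict(state_dict, mappings):
    """Apply each weight's tensor-parallel action to a loaded checkpoint."""
    for name, weight in list(state_dict.items()):
        action = mappings.get(name)
        if action is not None:
            state_dict[name] = action(weight)  # returns this rank's shard
    return state_dict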
class Ernie4_5_MLP(nn.Layer): class Ernie4_5_MLP(nn.Layer):
def __init__( def __init__(
@@ -329,6 +51,7 @@ class Ernie4_5_MLP(nn.Layer):
fd_config: FDConfig, fd_config: FDConfig,
intermediate_size: int, intermediate_size: int,
prefix: str = "", prefix: str = "",
reduce_results: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
self.nranks = fd_config.parallel_config.tensor_parallel_degree self.nranks = fd_config.parallel_config.tensor_parallel_degree
@@ -345,7 +68,7 @@ class Ernie4_5_MLP(nn.Layer):
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
fd_config=fd_config, fd_config=fd_config,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
input_size=(intermediate_size // self.nranks), input_size=intermediate_size,
output_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.hidden_size,
with_bias=False, with_bias=False,
) )
@@ -423,8 +146,8 @@ class Ernie4_5_MoE(nn.Layer):
f"{prefix}.experts.{{}}.down_proj.code_zp", f"{prefix}.experts.{{}}.down_proj.code_zp",
} }
elif moe_quant_type == "tensor_wise_fp8" or ( elif moe_quant_type == "tensor_wise_fp8" or (
moe_quant_type == "block_wise_fp8" and moe_quant_type == "block_wise_fp8"
fd_config.model_config.is_quantized): and fd_config.model_config.is_quantized):
weight_key_map = { weight_key_map = {
"gate_weight_key": "gate_weight_key":
f"{prefix}.gate.weight", f"{prefix}.gate.weight",
@@ -492,8 +215,6 @@ class Ernie4_5_Attention(nn.Layer):
prefix: str) -> None: prefix: str) -> None:
super().__init__() super().__init__()
nranks = fd_config.parallel_config.tensor_parallel_degree
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
fd_config=fd_config, fd_config=fd_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
@@ -502,8 +223,8 @@ class Ernie4_5_Attention(nn.Layer):
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
fd_config=fd_config, fd_config=fd_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
input_size=(fd_config.model_config.head_dim * input_size=fd_config.model_config.head_dim *
fd_config.model_config.num_attention_heads // nranks), fd_config.model_config.num_attention_heads,
output_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.hidden_size,
) )
self.attn = Attention( self.attn = Attention(
@@ -636,12 +357,12 @@ class Ernie4_5_Model(nn.Layer):
params_dtype=paddle.get_default_dtype(), params_dtype=paddle.get_default_dtype(),
prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens")) prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"))
self.hidden_layers = [ self.hidden_layers = nn.LayerList([
Ernie4_5_DecoderLayer( Ernie4_5_DecoderLayer(
fd_config=fd_config, fd_config=fd_config,
prefix=f"{fd_config.model_config.prefix_name}.layers.{i}") prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
for i in range(self.num_layers) for i in range(self.num_layers)
] ])
self.norm = RMSNorm( self.norm = RMSNorm(
fd_config, fd_config,
@@ -772,3 +493,134 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
Model Architecture Name Model Architecture Name
""" """
return "Ernie4_5_ForCausalLM" return "Ernie4_5_ForCausalLM"
class Ernie4_5_PretrainedModel(PretrainedModel):
"""
Ernie4_5_PretrainedModel
"""
config_class = FDConfig
def _init_weight(self, layer):
"""
_init_weight
"""
return None
weight_infos = [
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight",
True, tsm.GQA),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight",
False),
WeightMeta(
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.weight",
True, tsm.PairFused),
WeightMeta(f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.weight",
False),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.weight",
True, tsm.PairFused),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.weight",
False),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.weight",
True, tsm.PairFused),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
False),
WeightMeta(".embed_tokens.weight", False),
WeightMeta("lm_head.weight", True),
# quant tensorwise
WeightMeta(
f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.quant_weight",
True, tsm.GQA),
WeightMeta(
f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.quant_weight",
False),
WeightMeta(
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.quant_weight",
True, tsm.PairFused),
WeightMeta(
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.quant_weight",
False),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.quant_weight",
True, tsm.PairFused),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.quant_weight",
False),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.quant_weight",
True, tsm.PairFused),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.quant_weight",
False),
]
@classmethod
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
"""
get_tensor_parallel_mappings
"""
logger.info("erine inference model _get_tensor_parallel_mappings")
from fastdeploy.model_executor.models.tp_utils import (
build_expanded_keys, has_prefix, split_or_merge_func_v1)
fn = split_or_merge_func_v1(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
head_dim=config.head_dim)
def get_tensor_parallel_split_mappings(num_layers, moe_num_experts,
moe_layer_start_index,
prefix_name):
base_actions = {}
weight_infos = cls.weight_infos
for (weight_name, is_column, extra) in weight_infos:
params = {
"is_column": is_column,
**({
extra.value: True
} if extra else {})
}
if "lm_head.weight" in weight_name:
key = weight_name
elif not has_prefix(prefix_name, weight_name):
key = f"{prefix_name}{weight_name}"
else:
key = weight_name
base_actions[key] = partial(fn, **params)
final_actions = {}
start_layer = (moe_layer_start_index
if moe_layer_start_index > 0 else num_layers)
final_actions = build_expanded_keys(
num_layers,
moe_num_experts,
start_layer,
base_actions,
)
return final_actions
moe_num_experts = 0
if isinstance(config.moe_num_experts, list):
moe_num_experts = sum(config.moe_num_experts)
elif isinstance(config.moe_num_experts, int):
moe_num_experts = config.moe_num_experts
moe_layer_start_index = -1
if isinstance(config.moe_layer_start_index, list):
moe_layer_start_index = min(config.moe_layer_start_index)
elif isinstance(config.moe_layer_start_index, int):
moe_layer_start_index = config.moe_layer_start_index
mappings = get_tensor_parallel_split_mappings(config.num_layers,
moe_num_experts,
moe_layer_start_index,
config.prefix_name)
return mappings


@@ -265,12 +265,12 @@ class Ernie4_5_MTPModel(nn.Layer):
self.num_layers = fd_config.model_config.num_layers self.num_layers = fd_config.model_config.num_layers
self.embeddings = fd_config.speculative_config.sharing_model.model.embeddings self.embeddings = fd_config.speculative_config.sharing_model.model.embeddings
self.hidden_layers = [ self.hidden_layers = nn.LayerList([
Ernie4_5_DecoderLayer( Ernie4_5_DecoderLayer(
fd_config=fd_config, fd_config=fd_config,
prefix=f"{fd_config.model_config.prefix_name}.{i}") prefix=f"{fd_config.model_config.prefix_name}.{i}")
for i in range(self.num_layers) for i in range(self.num_layers)
] ])
self.enorm = RMSNorm( self.enorm = RMSNorm(
fd_config, fd_config,

View File

@@ -25,6 +25,8 @@ from paddle import nn
from paddleformers.utils.log import logger from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.moe.moe import FusedMoE
@@ -66,6 +68,7 @@ class Ernie4_5_VLMoE(nn.Layer):
prefix: str) -> None: prefix: str) -> None:
super().__init__() super().__init__()
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
moe_layer_start_index = fd_config.moe_config.moe_layer_start_index moe_layer_start_index = fd_config.moe_config.moe_layer_start_index
if isinstance(moe_layer_start_index, int): if isinstance(moe_layer_start_index, int):
text_moe_layer_start_index = moe_layer_start_index text_moe_layer_start_index = moe_layer_start_index
@@ -99,6 +102,7 @@ class Ernie4_5_VLMoE(nn.Layer):
} }
self.mlp_text = FusedMoE( self.mlp_text = FusedMoE(
fd_config=fd_config, fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.moe_config. moe_intermediate_size=fd_config.moe_config.
moe_intermediate_size[0], moe_intermediate_size[0],
num_experts=fd_config.moe_config.num_experts[0], num_experts=fd_config.moe_config.num_experts[0],
@@ -130,6 +134,7 @@ class Ernie4_5_VLMoE(nn.Layer):
} }
self.mlp_image = FusedMoE( self.mlp_image = FusedMoE(
fd_config=fd_config, fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.moe_config. moe_intermediate_size=fd_config.moe_config.
moe_intermediate_size[1], moe_intermediate_size[1],
num_experts=fd_config.moe_config.num_experts[1], num_experts=fd_config.moe_config.num_experts[1],
@@ -154,6 +159,7 @@ class Ernie4_5_VLMoE(nn.Layer):
intermediate_size=self.num_shared_experts * intermediate_size=self.num_shared_experts *
fd_config.moe_config.moe_intermediate_size[0], fd_config.moe_config.moe_intermediate_size[0],
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
reduce_results=False,
) )
def extract_gate_correction_bias_text(self, gate_correction_bias_key, def extract_gate_correction_bias_text(self, gate_correction_bias_key,
@@ -210,6 +216,8 @@ class Ernie4_5_VLMoE(nn.Layer):
hidden_states = self.mlp_text(hidden_states) hidden_states = self.mlp_text(hidden_states)
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
hidden_states += share_experts_out hidden_states += share_experts_out
if self.tp_size > 1:
tensor_model_parallel_all_reduce(hidden_states)
return hidden_states return hidden_states
@@ -337,12 +345,12 @@ class Ernie4_5_VLModel(nn.Layer):
prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"), prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"),
) )
self.hidden_layers = [ self.hidden_layers = nn.LayerList([
Ernie4_5_VLDecoderLayer( Ernie4_5_VLDecoderLayer(
fd_config=fd_config, fd_config=fd_config,
prefix=f"{fd_config.model_config.prefix_name}.layers.{i}") prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
for i in range(self.num_layers) for i in range(self.num_layers)
] ])
self.norm = RMSNorm( self.norm = RMSNorm(
fd_config, fd_config,


@@ -28,15 +28,17 @@ class ModelRegistry:
_registry = {} _registry = {}
@classmethod @classmethod
def register(cls, model_class): def register(cls, model_class, suffix=""):
"""register model class"""
if issubclass( if issubclass(
model_class, model_class,
ModelForCasualLM) and model_class is not ModelForCasualLM: ModelForCasualLM) and model_class is not ModelForCasualLM:
cls._registry[model_class.name()] = model_class cls._registry[f"{model_class.name()}{suffix}"] = model_class
return model_class return model_class
@classmethod @classmethod
def get_class(cls, name): def get_class(cls, name):
"""get model class"""
if name not in cls._registry: if name not in cls._registry:
raise ValueError(f"Model '{name}' is not registered!") raise ValueError(f"Model '{name}' is not registered!")
return cls._registry[name] return cls._registry[name]
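# A minimal sketch of the new `suffix` argument (MyCausalLM is hypothetical):
# `register` now stores the class under "<name()><suffix>", which is what the RL
# package below relies on to register "<Model>RL" variants.
class MyCausalLM(ModelForCasualLM):

    @classmethod
    def name(cls):
        return "MyCausalLM"

ModelRegistry.register(MyCausalLM, suffix="RL")
assert ModelRegistry.get_class("MyCausalLMRL") is MyCausalLM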


@@ -61,7 +61,7 @@ class Qwen2MLP(nn.Layer):
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
fd_config=fd_config, fd_config=fd_config,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
input_size=(fd_config.model_config.ffn_hidden_size // self.nranks), input_size=fd_config.model_config.ffn_hidden_size,
output_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.hidden_size,
with_bias=False, with_bias=False,
) )
@@ -97,8 +97,6 @@ class Qwen2Attention(nn.Layer):
prefix: str = "") -> None: prefix: str = "") -> None:
super().__init__() super().__init__()
nranks = fd_config.parallel_config.tensor_parallel_degree
self.qkv_proj = QKVParallelLinear(fd_config=fd_config, self.qkv_proj = QKVParallelLinear(fd_config=fd_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
with_bias=True) with_bias=True)
@@ -106,7 +104,7 @@ class Qwen2Attention(nn.Layer):
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
fd_config=fd_config, fd_config=fd_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
input_size=(fd_config.model_config.hidden_size // nranks), input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.hidden_size,
) )


@@ -68,7 +68,7 @@ class Qwen3Attention(nn.Layer):
fd_config=fd_config, fd_config=fd_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
input_size=fd_config.model_config.head_dim * input_size=fd_config.model_config.head_dim *
fd_config.model_config.num_attention_heads // nranks, fd_config.model_config.num_attention_heads,
output_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.hidden_size,
) )


@@ -63,7 +63,7 @@ class Qwen3MLP(nn.Layer):
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
fd_config, fd_config,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
input_size=(fd_config.model_config.ffn_hidden_size // self.nranks), input_size=fd_config.model_config.ffn_hidden_size,
output_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.hidden_size,
with_bias=False, with_bias=False,
) )
@@ -111,7 +111,7 @@ class Qwen3Attention(nn.Layer):
fd_config, fd_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
input_size=fd_config.model_config.head_dim * input_size=fd_config.model_config.head_dim *
fd_config.model_config.num_attention_heads // nranks, fd_config.model_config.num_attention_heads,
output_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.hidden_size,
) )


@@ -0,0 +1,405 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import re
from enum import Enum
from functools import partial
from typing import Dict, List
import numpy as np
import paddle
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.conversion_utils import split_or_merge_func
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder
def check_tensor_parallel_prerequisites(
fd_config: FDConfig,
cls: PretrainedModel,
tensor_parallel_filtered_map: Dict[str, partial],
safetensor_keys: List[str],
) -> None:
"""check_tensor_parallel_prerequisites"""
if fd_config.parallel_config.tensor_parallel_degree > 1:
tensor_parallel_map = cls._get_tensor_parallel_mappings(
fd_config.model_config, is_split=True)
if not tensor_parallel_map:
            logger.error(
                "tensor_parallel_map should not be empty. Tensor parallel "
                "splitting is required, but _get_tensor_parallel_mappings is not implemented."
            )
filtered_tp_keys = cls._resolve_prefix_keys(tensor_parallel_map.keys(),
safetensor_keys)
for k, v in filtered_tp_keys.items():
tensor_parallel_filtered_map[v] = tensor_parallel_map.pop(k)
if not tensor_parallel_filtered_map:
            logger.error(
                "tensor_parallel_filtered_map should not be empty. The weights required "
                "for tensor parallel splitting are inconsistent with the model's weights."
            )
def extract_prefix(weight_name: str) -> str:
"""extract_prefix"""
if weight_name.startswith("."):
return ""
parts = weight_name.split(".", 1)
return parts[0] if len(parts) > 1 else ""
def has_prefix(prefix_name: str, weight_name: str):
"""has_prefix"""
return prefix_name == extract_prefix(weight_name)
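# A small sketch of the prefix helpers (values below are illustrative):
# names starting with "." are treated as prefix-less templates.
assert extract_prefix("ernie.embed_tokens.weight") == "ernie"
assert extract_prefix(".layers.0.mlp.down_proj.weight") == ""
assert has_prefix("ernie", "ernie.embed_tokens.weight")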
class TensorSplitMode(Enum):
"""TensorSplitMode"""
GQA = "is_gqa"
TRANSPOSE = "transpose"
QKV = "is_old_qkv"
PairFused = "is_naive_2fuse"
TripletFused = "is_naive_3fuse"
def extract_placeholders(template: str):
"""extract_placeholders"""
return set(re.findall(r"{(\w+)}", template))
class SafeDict(dict):
"""SafeDict"""
def __missing__(self, key):
return "{" + key + "}"
def has_placeholders(placeholders):
"""has_placeholders"""
return len(placeholders) > 0
def update_final_actions(params, final_actions, key, action):
"""update_final_actions"""
new_key = key.format_map(SafeDict(params))
final_actions[new_key] = action
def build_expanded_keys(num_layers, num_experts, start_layer, base_actions):
"""build_expanded_keys"""
final_actions = {}
for key, action in base_actions.items():
placeholders = extract_placeholders(key)
if not has_placeholders(placeholders):
final_actions[key] = action
else:
if LayerIdPlaceholder.LAYER_ID.value in placeholders:
for layer_id in range(num_layers):
update_final_actions(
{LayerIdPlaceholder.LAYER_ID.value: layer_id},
final_actions,
key,
action,
)
elif LayerIdPlaceholder.FFN_LAYER_ID.value in placeholders:
for layer_id in range(start_layer):
update_final_actions(
{LayerIdPlaceholder.FFN_LAYER_ID.value: layer_id},
final_actions,
key,
action,
)
elif (LayerIdPlaceholder.MOE_LAYER_ID.value in placeholders
and LayerIdPlaceholder.EXPERT_ID.value in placeholders):
for layer_id in range(start_layer, num_layers):
for export_id in range(num_experts):
update_final_actions(
{
LayerIdPlaceholder.MOE_LAYER_ID.value:
layer_id,
LayerIdPlaceholder.EXPERT_ID.value: export_id,
},
final_actions,
key,
action,
)
elif (LayerIdPlaceholder.MOE_LAYER_ID.value in placeholders
and len(placeholders) == 1):
for layer_id in range(start_layer, num_layers):
update_final_actions(
{LayerIdPlaceholder.MOE_LAYER_ID.value: layer_id},
final_actions,
key,
action,
)
else:
logger.error(f"{key} does not match any case.")
return final_actions
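# A minimal sketch of the key expansion (string actions stand in for the real
# partial split functions; placeholder names follow the LayerIdPlaceholder values):
_example_base = {
    ".layers.{layer_id}.self_attn.o_proj.weight": "row_split",
    ".layers.{moe_layer_id}.mlp.experts.{export_id}.down_proj.weight": "row_split",
}
_example_expanded = build_expanded_keys(3, 2, 1, _example_base)
# yields ".layers.0.self_attn.o_proj.weight" ... ".layers.2.self_attn.o_proj.weight"
# plus ".layers.1.mlp.experts.0.down_proj.weight" ... ".layers.2.mlp.experts.1.down_proj.weight"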
def gqa_qkv_split_func(
tensor_parallel_degree,
tensor_parallel_rank,
num_attention_heads,
num_key_value_heads,
head_dim,
):
"""
gqa_qkv_split_func
"""
def fn(x, is_column=True):
"""fucn"""
def get_shape(tensor):
"""get_shape"""
return tensor.get_shape() if hasattr(tensor,
"get_shape") else tensor.shape
def slice_tensor(tensor, start, end):
"""slice_tensor"""
shape = get_shape(tensor)
if len(shape) == 1:
return tensor[start:end]
elif is_column:
return tensor[..., start:end]
else:
return tensor[start:end, ...]
q_end = num_attention_heads * head_dim
k_end = q_end + num_key_value_heads * head_dim
v_end = k_end + num_key_value_heads * head_dim
q = slice_tensor(x, 0, q_end)
k = slice_tensor(x, q_end, k_end)
v = slice_tensor(x, k_end, v_end)
def split_tensor(tensor, degree):
"""
split_tensor
"""
shape = get_shape(tensor)
size = shape[-1] if is_column else shape[0]
block_size = size // degree
if hasattr(tensor, "get_shape"):
return [
slice_tensor(tensor, i * block_size, (i + 1) * block_size)
for i in range(degree)
]
else:
if isinstance(x, paddle.Tensor):
if is_column:
return paddle.split(tensor, degree, axis=-1)
else:
return paddle.split(tensor, degree, axis=0)
else:
if is_column:
return np.split(tensor, degree, axis=-1)
else:
return np.split(tensor, degree, axis=0)
q_list = split_tensor(q, tensor_parallel_degree)
k_list = split_tensor(k, tensor_parallel_degree)
v_list = split_tensor(v, tensor_parallel_degree)
if tensor_parallel_rank is None:
res = []
for q_i, k_i, v_i in zip(q_list, k_list, v_list):
if is_column:
if isinstance(x, paddle.Tensor):
res.append(paddle.concat([q_i, k_i, v_i], axis=-1))
else:
res.append(np.concatenate([q_i, k_i, v_i], axis=-1))
else:
if isinstance(x, paddle.Tensor):
res.append(paddle.concat([q_i, k_i, v_i], axis=0))
else:
res.append(np.concatenate([q_i, k_i, v_i], axis=0))
return res
else:
if isinstance(x, paddle.Tensor):
if is_column:
return paddle.concat(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
],
axis=-1,
)
else:
return paddle.concat(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
],
axis=0,
)
else:
if is_column:
return np.concatenate(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
],
axis=-1,
)
else:
return np.concatenate(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
],
axis=0,
)
return fn
def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
"""
gqa_qkv_merge_func
"""
def fn(weight_list, is_column=True):
"""fn"""
tensor_parallel_degree = len(weight_list)
        # use new local names: assigning to the closure arguments here would make
        # them locals of fn and raise UnboundLocalError on the right-hand side
        num_heads_per_rank = num_attention_heads // tensor_parallel_degree
        num_kv_heads_per_rank = num_key_value_heads // tensor_parallel_degree
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
def get_shape(tensor):
"""
get_shape
"""
return tensor.get_shape() if hasattr(tensor,
"get_shape") else tensor.shape
def slice_tensor(tensor, start, end):
"""
slice_tensor
"""
if len(get_shape(tensor)) == 1:
return tensor[start:end]
elif is_column:
return tensor[..., start:end]
else:
return tensor[start:end, ...]
q_list, k_list, v_list = [], [], []
for weight in weight_list:
            q_end = num_heads_per_rank * head_dim
            k_end = q_end + num_kv_heads_per_rank * head_dim
            v_end = k_end + num_kv_heads_per_rank * head_dim
q = slice_tensor(weight, 0, q_end)
k = slice_tensor(weight, q_end, k_end)
v = slice_tensor(weight, k_end, v_end)
q_list.append(q)
k_list.append(k)
v_list.append(v)
merged = q_list + k_list + v_list
if is_paddle_tensor:
if is_column:
tensor = paddle.concat(merged, axis=-1)
else:
tensor = paddle.concat(merged, axis=0)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor
else:
if is_column:
return np.concatenate(merged, axis=-1)
else:
return np.concatenate(merged, axis=0)
return fn
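# A minimal round-trip sketch (shapes are made up; numpy is imported above):
# splitting a fused [Q | K | V] weight into per-rank shards and merging the
# shards back reproduces the original layout.
_heads, _kv_heads, _head_dim, _tp = 8, 2, 4, 2
_fused = np.arange(16 * (_heads + 2 * _kv_heads) * _head_dim,
                   dtype="float32").reshape(16, -1)
_split = gqa_qkv_split_func(_tp, None, _heads, _kv_heads, _head_dim)
_shards = _split(_fused, is_column=True)  # rank=None -> one fused shard per rank
_merge = gqa_qkv_merge_func(_heads, _kv_heads, _head_dim)
assert np.array_equal(_merge(_shards, is_column=True), _fused)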
def split_or_merge_qkv_func(
is_split,
tensor_parallel_degree,
tensor_parallel_rank,
num_attention_heads,
num_key_value_heads,
head_dim,
):
"""
split_or_merge_qkv_func
"""
if is_split:
return gqa_qkv_split_func(
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
)
else:
return gqa_qkv_merge_func(
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
)
def split_or_merge_func_v1(
is_split,
tensor_parallel_degree,
tensor_parallel_rank,
num_attention_heads=None,
num_key_value_heads=None,
head_dim=None,
):
"""
split_or_merge_func_v1
"""
def fn(x, **kwargs):
"""func"""
is_gqa = kwargs.pop("is_gqa", False)
if is_gqa:
func = split_or_merge_qkv_func(
is_split=is_split,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
)
is_column = kwargs.pop("is_column", True)
return func(x, is_column=is_column)
else:
func = split_or_merge_func(
is_split=is_split,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
num_attention_heads=num_attention_heads,
)
is_column = kwargs.pop("is_column", True)
is_naive_2fuse = kwargs.pop("is_naive_2fuse", False)
return func(x, is_column=is_column, is_naive_2fuse=is_naive_2fuse)
return fn
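# A minimal dispatch sketch (degree/rank/head counts are made up): is_gqa=True
# routes to the fused-QKV helpers above, everything else falls back to
# paddleformers' split_or_merge_func.
_fn = split_or_merge_func_v1(
    is_split=True,
    tensor_parallel_degree=2,
    tensor_parallel_rank=0,
    num_attention_heads=8,
    num_key_value_heads=2,
    head_dim=4,
)
# _fn(qkv_weight, is_column=True, is_gqa=True)             -> this rank's fused QKV shard
# _fn(up_gate_weight, is_column=True, is_naive_2fuse=True) -> this rank's fused gate/up shard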


@@ -16,6 +16,7 @@
from __future__ import annotations from __future__ import annotations
import enum
import hashlib import hashlib
import json import json
import os import os
@@ -23,29 +24,47 @@ import random
import re import re
import struct import struct
from functools import partial from functools import partial
from typing import NamedTuple, Optional
import numpy as np import numpy as np
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from paddle.common_ops_import import convert_dtype from paddle.common_ops_import import convert_dtype
from paddle.distributed import fleet from paddle.distributed import fleet
from paddleformers.transformers.model_utils import (_add_variant, from paddleformers.transformers.model_utils import _add_variant
load_tp_checkpoint)
from paddleformers.transformers.utils import paddleformers_load from paddleformers.transformers.utils import paddleformers_load
from paddleformers.utils.env import (PADDLE_WEIGHTS_INDEX_NAME, from paddleformers.utils.env import (PADDLE_WEIGHTS_INDEX_NAME,
SAFE_MASTER_WEIGHTS_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME,
SAFE_PEFT_WEIGHTS_INDEX_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_INDEX_NAME) SAFE_WEIGHTS_INDEX_NAME)
from paddleformers.utils.log import logger from paddleformers.utils.log import logger
from safetensors import safe_open
from tqdm import tqdm from tqdm import tqdm
from fastdeploy.config import ModelConfig
MAX_BSZ = 512 MAX_BSZ = 512
MAX_DRAFT_TOKENS = 6 MAX_DRAFT_TOKENS = 6
class LayerIdPlaceholder(str, enum.Enum):
"""LayerIdPlaceholder"""
LAYER_ID = "layer_id"
FFN_LAYER_ID = "ffn_layer_id"
MOE_LAYER_ID = "moe_layer_id"
EXPERT_ID = "export_id"
class WeightMeta(NamedTuple):
"""
#tensor split parameters
# weight_name: weight name
# is_column: whether to split by columns
# extra: optional flags like "is_naive_2fuse", "is_gqa", "is_naive_3fuse"
"""
weight_name: str
is_column: bool
extra: Optional[str] = None
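# A small sketch of one entry (the `extra` slot usually carries a TensorSplitMode
# member, whose .value is the flag name handed to the split function):
_example_meta = WeightMeta(".layers.{layer_id}.mlp.up_gate_proj.weight",
                           True, "is_naive_2fuse")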
class UniqueIDGenerator: class UniqueIDGenerator:
""" """
The generator for the export model id The generator for the export model id
@@ -433,223 +452,6 @@ def calculate_effective_tokens(training_args, train_dataset, max_seq_len):
return total_effective_tokens, total_tokens return total_effective_tokens, total_tokens
def load_ep_checkpoint(model_path: str,
config: ModelConfig,
return_numpy: bool = False,
return_key_name: bool = True):
"""
load ep checkpoint
"""
# return_numpy=True cpu
# return_numpy=False gpu
with open(os.path.join(model_path, "model.safetensors.index.json"),
"r") as f:
weight_list = json.load(f)["weight_map"]
filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}
num_local_ffn_keys = []
for i in range(config.moe_layer_start_index, config.num_layers):
for j in range(
config.num_experts_start_offset,
config.num_experts_start_offset + config.num_experts_per_rank,
):
ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
ffn2_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
ffn2_quant_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight")
ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
ffn2_scale_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale")
num_local_ffn_keys.append(ffn1_key)
num_local_ffn_keys.append(ffn2_key)
num_local_ffn_keys.append(ffn1_quant_key)
num_local_ffn_keys.append(ffn2_quant_key)
num_local_ffn_keys.append(ffn1_scale_key)
num_local_ffn_keys.append(ffn2_scale_key)
for k in num_local_ffn_keys:
if k in weight_list:
filtered_map[k] = weight_list[k]
state_dict = {}
# Get all safetensor file paths that need to be opened
safetensor_paths = set(filtered_map.values())
# Open each safetensor file sequentially with progress bar
for safetensor_path in tqdm(safetensor_paths,
desc="Loading safetensor files",
unit="file"):
with safe_open(os.path.join(model_path, safetensor_path),
framework="np",
device="cpu") as f:
# Check if this file contains keys from filtered_map
for k in filtered_map:
if filtered_map[k] == safetensor_path and k in f.keys():
weight = f.get_tensor(k)
if not return_numpy:
weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(
paddle.framework._current_expected_place(), False)
state_dict[k] = weight
return state_dict
def get_safetensor_file(model_path):
"""
get_safetensor_file
"""
with open(os.path.join(model_path, "model.safetensors.index.json"),
"r") as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(
os.path.join(model_path, weight_map[weight_name]))
key_name_list = list(set(weight_map.keys()))
safetensor_list = list(weight_files_in_index)
safetensor_list.sort()
return key_name_list, safetensor_list
def safetensors_weights_iterator(safe_tensor_list: list[str], ):
"""
safetensors_weights_iterator
"""
for st_file in tqdm(
safe_tensor_list,
desc="Loading safetensors checkpoint shards",
):
with safe_open(st_file, framework="np") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
def fastsafetensors_weights_iterator(safetensor_list: list[str]):
"""
fastsafetensors_weights_iterator
"""
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
world_size = dist.get_world_size()
if world_size > 1:
dist.init_parallel_env()
pg = dist.get_group()
device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu"
else:
pg = SingleGroup()
device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda(
) else "cpu"
safetensor_files_sub_lists = [
safetensor_list[i:i + world_size]
for i in range(0, len(safetensor_list), world_size)
]
for st_file in tqdm(
safetensor_files_sub_lists,
desc="Loading fastsafetensors checkpoint shards",
):
loader = SafeTensorsFileLoader(pg,
device,
nogds=True,
debug_log=False,
framework="paddle")
rank_file_map = {i: [f] for i, f in enumerate(st_file)}
loader.add_filenames(rank_file_map)
try:
fb = loader.copy_files_to_device()
try:
keys = list(fb.key_to_rank_lidx.keys())
for k in keys:
t = fb.get_tensor(k)
yield k, t
finally:
fb.close()
finally:
loader.close()
def get_state_dict(model_path, config, use_fastsafetensor=False):
"""
get_state_dict
"""
state_dict = {}
_, safetensor_list = get_safetensor_file(
os.path.join(model_path, f"rank{config.tensor_parallel_rank}"))
if use_fastsafetensor:
weights_iterator = fastsafetensors_weights_iterator(safetensor_list)
else:
weights_iterator = safetensors_weights_iterator(safetensor_list)
for name, weight in weights_iterator:
state_dict[name] = weight
return state_dict
def apply_quant(name_action_quant_mappings, key, tensor, state_dict):
"""
apply_quant
"""
if key in name_action_quant_mappings:
action = name_action_quant_mappings.pop(key)
quant_weight_tensor, weight_scale_tensor = action(tensor)
if quant_weight_tensor is not None and weight_scale_tensor is not None:
state_dict[key + ".quant_weight"] = quant_weight_tensor
state_dict[key + ".weight_scale"] = weight_scale_tensor
else:
state_dict[key] = quant_weight_tensor
else:
state_dict[key] = tensor
def load_checkpoint(model_path, cls, config, return_numpy=True, load_gpu=True):
"""
load checkpoint
"""
if getattr(config, "parallel_config", None) is not None:
use_ep = getattr(config.parallel_config, "use_ep", False)
tensor_parallel_degree = config.parallel_config.tensor_parallel_degree
else:
use_ep = getattr(config, "use_ep", False)
tensor_parallel_degree = config.tensor_parallel_degree
if getattr(config, "model_config", None) is not None:
model_config = config.model_config
else:
model_config = config
if use_ep:
state_dict = load_ep_checkpoint(model_path,
config,
return_numpy=True,
return_key_name=True)
else:
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank")
and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if tensor_parallel_degree != len(rank_dirs):
raise ValueError(
f"Your model only supports loading with tp{len(rank_dirs)}"
)
state_dict = get_state_dict(model_path, model_config)
else:
state_dict = load_tp_checkpoint(model_path,
cls,
model_config,
return_numpy=return_numpy)
import re
for k, v in state_dict.items():
match = re.search(r'layers\.(\d+)', k)
if match and int(match.group(1)) > 0:
continue
return state_dict
def parser_quant_type(quant_type): def parser_quant_type(quant_type):
""" """
Parse the quantization type string and return the corresponding quantization types for weights, Parse the quantization type string and return the corresponding quantization types for weights,


@@ -0,0 +1,354 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import importlib
import inspect
import os
import re
import sys
import paddle
import triton
from .triton_utils import (SubstituteTemplate, build_package, compile_file,
extract_triton_kernel, find_so_path,
get_pointer_hint, link_file, multi_process_do,
python_path, rename_c_to_cu)
def get_value_hint(x):
"""
Get the value hint from input list.
"""
hint = ""
for ele in x:
if isinstance(ele, int):
hint += "i64,"
continue
if ele % 16 == 0 and ele > 0:
hint += "i64:16,"
elif ele == 1:
hint += "i64:1,"
else:
hint += "i64,"
if isinstance(ele, float):
hint += "fp32,"
return hint
common_template = ("""
#include "${op_name}_kernel.h"
#include "paddle/extension.h"
void ${op_name}_func(${tensor_and_attr}) {
auto run_stream = a_ptr->stream();
auto res_flag = ${op_name}_kernel(run_stream, ${triton_kernel_args}, 0);
if (res_flag == CUDA_ERROR_INVALID_VALUE) {
PD_THROW("${op_name}_kernel failed");
}
}
PYBIND11_MODULE(${op_name}_package, m) {
  m.def("${op_name}_func", ${op_name}_func, "launch the ${op_name} triton kernel");
}
""")
class KernelInterface:
"""
triton kernel interface.
"""
def __init__(
self,
func,
other_config,
key_args=["1"],
):
"""
triton kernel interface.
"""
self.func = func
self.key_args = key_args
signature = inspect.signature(func)
self.arg_names = [v.name for v in signature.parameters.values()]
for ele in self.arg_names:
assert self.arg_names.count(ele) == 1
# arg_defaults = [v.default for v in signature.parameters.values()]
# self.annotations = {
# name: ty for name, ty in func.__annotations__.items()
# }
self.annotations = dict(func.__annotations__)
self.constexprs = [
self.arg_names.index(name) for name in self.arg_names
if self.annotations.get(name) == triton.language.core.constexpr
]
self.arg_exclude_constexpr = [
self.arg_names[i] for i in range(len(self.arg_names))
if i not in self.constexprs
]
import textwrap
py_script = textwrap.dedent(inspect.getsource(func))
pat = r"def\s" + func.__name__
func_begin = re.findall(pat, py_script)
assert len(func_begin) == 1
func_begin = func_begin[0]
py_script = py_script[py_script.find(func_begin):]
self.func_map = {}
def decorator(*args, **kwargs):
"""
decorator for triton kernels.
Args:
*args: positional arguments
**kwargs: keyword arguments
"""
op_name = "haha" + str(kwargs["N"])
if op_name in self.func_map.keys():
return self.func_map[op_name](*args)
all_input = []
for i in range(len(args)):
all_input.append(args[i])
position_arguments_num = len(all_input)
for i in range(position_arguments_num, len(self.arg_names)):
if self.arg_names[i] in kwargs.keys():
all_input.append(kwargs[self.arg_names[i]])
else:
                    # this input was not specified, so it must be a tl.constexpr.
assert i in self.constexprs
all_input.append(None)
dtypes = []
x_list = []
const_args = [self.arg_names[i] for i in self.constexprs]
decalare_arg_exclude_constexpr = list(self.arg_exclude_constexpr)
passed_arg_exclude_constexpr = list(self.arg_exclude_constexpr)
const_hint_dict = {}
for i in range(len(all_input)):
ele = all_input[i]
if type(ele) in [
paddle.Tensor, paddle.base.framework.EagerParamBase,
paddle.base.framework.Parameter,
paddle.base.framework.Variable,
paddle.base.libpaddle.pir.Value,
type(None)
]:
if ele is not None:
dtypes.append(ele.dtype)
passed_arg_exclude_constexpr[
i] = f"(CUdeviceptr)({passed_arg_exclude_constexpr[i]}->data())"
else:
dtypes.append(paddle.int8)
passed_arg_exclude_constexpr[
i] = "(CUdeviceptr)(nullptr)"
decalare_arg_exclude_constexpr[
i] = "const paddle::optional<paddle::Tensor>&" + decalare_arg_exclude_constexpr[
i]
elif i in self.constexprs:
if isinstance(ele, bool):
const_hint_dict[self.arg_names[i]] = (int)(ele)
elif isinstance(ele, int):
if ele < 0:
const_hint_dict[self.arg_names[i]] = 0
else:
const_hint_dict[self.arg_names[i]] = ele
else:
assert False
else:
x_list.append(ele)
if isinstance(ele, int):
decalare_arg_exclude_constexpr[
i] = "const int64_t " + decalare_arg_exclude_constexpr[
i]
elif isinstance(ele, float):
decalare_arg_exclude_constexpr[
i] = "const float " + decalare_arg_exclude_constexpr[
i]
else:
assert False
python_package_name = f"{op_name}_package"
tp_rank = paddle.distributed.get_rank()
generated_dir = os.getenv("TRITON_KERNEL_CACHE_DIR",
f"/tmp/triton_cache/rank{tp_rank}")
print("the kernel cache dir is:", generated_dir)
generated_dir = f"{generated_dir}/{op_name}"
os.makedirs(generated_dir, exist_ok=True)
py_script_file = f"{generated_dir}/triton_kernels.py"
extract_triton_kernel(func, py_script_file)
address_hint = get_pointer_hint(dtypes)
value_hint = get_value_hint(x_list)
const_args = [f"{{{ele}}}" for ele in const_args]
const_args = ",".join(const_args)
lanuch_grid = list(self.grid)
for i in range(len(lanuch_grid)):
ele = lanuch_grid[i]
if isinstance(ele, str):
keys = list(const_hint_dict.keys())
keys.sort(key=len, reverse=True)
for key in keys:
if key in ele:
ele = ele.replace(key, f"{const_hint_dict[key]}")
else:
ele = str(ele)
lanuch_grid[i] = ele
if len(lanuch_grid) < 3:
lanuch_grid += ["1"] * (3 - len(lanuch_grid))
lanuch_grid = ",".join(lanuch_grid)
op_dict = {"op_name": op_name}
op_dict["triton_kernel_args"] = ",".join(
passed_arg_exclude_constexpr)
op_dict["tensor_and_attr"] = ",".join(
decalare_arg_exclude_constexpr)
paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu"
so_path = find_so_path(generated_dir, python_package_name)
if so_path is None:
                print("== could not find the so_path, need to compile it")
with open(paddle_custom_op_file_path, "w") as f:
f.write(SubstituteTemplate(
common_template,
op_dict,
))
f.close()
# ahead of time compile command.
aot_template = (
f"""{python_path} {compile_file} {py_script_file} """ +
f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
+ f"""--out-name {op_name}_kernel """ +
""" -w {num_warps} -ns {num_stages} """ +
f""" -s"{address_hint} {value_hint} {const_args}" """ +
f""" -g "{lanuch_grid}" """)
all_tune_config = [const_hint_dict]
# reset const_hint_dict as empty.
const_hint_dict = {}
codegen_commands = []
for config in all_tune_config:
for key in const_hint_dict.keys():
if const_hint_dict[key] is not None:
if key not in config.keys():
config[key] = const_hint_dict[key]
else:
if config[key] == const_hint_dict[key]:
pass
else:
message = (
f"you specify {key} both in arguments and config, "
"and they are not same, this is wrong."
)
raise ValueError(message)
else:
assert key in config.keys(
), f"you must specify {key} in your config."
if "num_warps" not in config.keys():
config["num_warps"] = 4
if "num_stages" not in config.keys():
config["num_stages"] = 4
for key in config:
assert config[
key] is not None, f"{key} must be specified."
codegen_command = aot_template.format(**config, )
print(codegen_command)
codegen_commands.append(codegen_command)
multi_process_do(codegen_commands)
link_command = (
f"{python_path} {link_file} "
f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel")
re = os.system(link_command)
assert re == 0
# rename the .c file to .cu
rename_c_to_cu(generated_dir)
# build the package to so, not install
build_package(generated_dir, python_package_name)
# so_path have be found!
so_path = find_so_path(generated_dir, python_package_name)
print("== we find so_path: ", so_path)
assert so_path is not None
dir_path = os.path.dirname(so_path)
sys.path.append(dir_path)
lib = importlib.import_module(python_package_name)
pybind_func = getattr(lib, f"{op_name}_func")
self.func_map[op_name] = pybind_func
# run this op!
self.func_map[op_name](*args)
self.decorator = decorator
def __getitem__(self, op_name_and_grid):
"""
override the operator [], which will call the decorator function.
Args:
op_name_and_grid: the name of the operator and the grid size.
Returns:
the decorator function.
"""
self.grid = ((
"((max_possible_num_post_padded + BLOCK_SIZE_M -1)/ BLOCK_SIZE_M) * ((N + BLOCK_SIZE_N-1) / BLOCK_SIZE_N)"
), )
return self.decorator
def paddle_use_triton_v2(other_config={}, key=[]):
"""
The decorator function that wraps the original function.
Args:
func: the original function.
Returns:
the wrapped function.
"""
def decorator(func):
"""
The decorator function that wraps the original function.
Args:
func: the original function.
Returns:
the wrapped function.
"""
return KernelInterface(func, other_config, key)
return decorator


@@ -17,6 +17,7 @@ from typing import Dict, Optional
import paddle import paddle
from fastdeploy import envs
from fastdeploy.engine.config import SpeculativeConfig from fastdeploy.engine.config import SpeculativeConfig
from fastdeploy.model_executor.ops.gpu import ( from fastdeploy.model_executor.ops.gpu import (
get_padding_offset, save_output, set_stop_value_multi_ends, get_padding_offset, save_output, set_stop_value_multi_ends,
@@ -24,10 +25,11 @@ from fastdeploy.model_executor.ops.gpu import (
speculate_get_padding_offset, speculate_get_seq_lens_output, speculate_get_padding_offset, speculate_get_seq_lens_output,
speculate_save_output, speculate_set_value_by_flags_and_idx, speculate_save_output, speculate_set_value_by_flags_and_idx,
speculate_step_paddle, speculate_step_system_cache, speculate_update_v3, speculate_step_paddle, speculate_step_system_cache, speculate_update_v3,
step_paddle, step_system_cache, update_inputs) step_paddle, step_system_cache, update_inputs, step_reschedule)
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import ModelOutputData from fastdeploy.worker.output import ModelOutputData
DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1")
def pre_process( def pre_process(
max_len: int, max_len: int,
@@ -214,6 +216,8 @@ def step_cuda(
""" """
TODO(gongshaotian): normalization name TODO(gongshaotian): normalization name
""" """
if speculative_config.method is not None: if speculative_config.method is not None:
if enable_prefix_caching: if enable_prefix_caching:
speculate_step_system_cache( speculate_step_system_cache(
@@ -291,6 +295,33 @@ def step_cuda(
share_inputs["input_ids"], share_inputs["pre_ids"], share_inputs["input_ids"], share_inputs["pre_ids"],
share_inputs["step_idx"], share_inputs["next_tokens"], share_inputs["step_idx"], share_inputs["next_tokens"],
share_inputs["first_token_ids"], block_size, enc_dec_block_num) share_inputs["first_token_ids"], block_size, enc_dec_block_num)
elif DISABLE_RECOVER:
step_reschedule(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
block_size,
enc_dec_block_num,
)
else: else:
step_paddle( step_paddle(
share_inputs["stop_flags"], share_inputs["stop_flags"],


@@ -341,7 +341,6 @@ class TokenProcessor(object):
result.prompt = task.prompt result.prompt = task.prompt
result.prompt_token_ids = task.prompt_token_ids result.prompt_token_ids = task.prompt_token_ids
if recovery_stop: if recovery_stop:
result.outputs.token_ids.append(task.eos_token_ids[0])
result.error_msg = "Recover is not supported, the result is incomplete!" result.error_msg = "Recover is not supported, the result is incomplete!"
llm_logger.info( llm_logger.info(
f"Request: {task_id} finished, number of " f"Request: {task_id} finished, number of "


@@ -15,11 +15,15 @@
platform interface file platform interface file
""" """
import paddle
import enum import enum
import paddle
class _Backend(enum.Enum): class _Backend(enum.Enum):
NATIVE_ATTN = enum.auto() NATIVE_ATTN = enum.auto()
APPEND_ATTN = enum.auto() APPEND_ATTN = enum.auto()
MLA_ATTN = enum.auto()
class Platform: class Platform:
@@ -71,8 +75,7 @@ class Platform:
if self.supported_quantization and quant not in self.supported_quantization: if self.supported_quantization and quant not in self.supported_quantization:
raise ValueError( raise ValueError(
f"{quant} quantization is currently not supported in " f"{quant} quantization is currently not supported in "
f"{self.device_name}." f"{self.device_name}.")
)
@classmethod @classmethod
def available(self): def available(self):


@@ -46,7 +46,7 @@ class CUDAPlatform(Platform):
return False return False
@classmethod @classmethod
def get_attention_backend_cls(cls, selected_backend): def get_attention_backend_cls(cls, selected_backend: _Backend):
""" """
get_attention_backend_cls get_attention_backend_cls
""" """
@@ -60,5 +60,13 @@ class CUDAPlatform(Platform):
return ( return (
"fastdeploy.model_executor.layers.attention.AppendAttentionBackend" "fastdeploy.model_executor.layers.attention.AppendAttentionBackend"
) )
elif selected_backend == _Backend.MLA_ATTN:
logger.info("Using MLA ATTN backend.")
return (
"fastdeploy.model_executor.layers.attention.MLAAttentionBackend"
)
else: else:
logger.warning("Other backends are not supported for now.") raise ValueError(
"Invalid attention backend you specified.\n"
"Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place."
)


@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
import os
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from functools import cached_property from functools import cached_property
@@ -21,7 +20,6 @@ from typing import Callable, Optional, Union
from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest, from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage) DeltaMessage)
from fastdeploy.utils import data_processor_logger
from fastdeploy.utils import is_list_of from fastdeploy.utils import is_list_of
@@ -120,7 +118,8 @@ class ReasoningParserManager:
reasoning_parsers: dict[str, type] = {} reasoning_parsers: dict[str, type] = {}
@classmethod @classmethod
def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]: def get_reasoning_parser(cls,
name: Optional[str]) -> type[ReasoningParser]:
""" """
Get reasoning parser by name which is registered by `register_module`. Get reasoning parser by name which is registered by `register_module`.

fastdeploy/rl/__init__.py (new file)

@@ -0,0 +1,20 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
from fastdeploy.model_executor.models import auto_models_registry
auto_models_registry(os.path.dirname(__file__), "fastdeploy.rl", suffix="RL")


@@ -0,0 +1,282 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import time
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Dict, List
import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.load_weight_utils import \
load_composite_checkpoint
from fastdeploy.model_executor.model_loader import MODEL_CLASSES
class DynamicWeightManager:
"""Manages model weights loading, updating and shared state across processes."""
def __init__(self, fd_config: FDConfig, model: nn.Layer):
"""Initialize with config and model instances."""
self.fd_config = fd_config
self.load_config = fd_config.load_config
self.parallel_config = fd_config.parallel_config
self.state_dict: Dict[str, paddle.Tensor] = {}
self.rank = fd_config.parallel_config.tensor_parallel_rank
self.nranks = paddle.distributed.get_world_size()
self.meta_src_id = self._get_gpu_id()
self.first_load = True
self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
self.models: List[nn.Layer] = [model]
self._capture_model_state()
if self.load_config.load_strategy != "meta":
self.update_parameters()
logger.info(
f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
f" rank={self.rank}, ranks={self.nranks}, "
f" load ipc weight from {self.ipc_path}.")
@paddle.no_grad()
def _capture_model_state(self):
"""Capture and store initial model parameters state."""
for model in self.models:
for name, param in model.state_dict().items():
logger.debug(
f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
)
self.state_dict[name] = param
def add_model(self, model: nn.Layer):
""""add model"""
self.models.append(model)
self._capture_model_state()
def update_parameters(self, pid: int = 0) -> None:
"""Core method to update model parameters based on strategy."""
start_time = time.perf_counter()
paddle.device.cuda.empty_cache()
if not self.first_load:
paddle.distributed.restart_process_group()
strategy_handlers = {
"ipc_snapshot": self._update_ipc_snapshot,
"ipc": self._update_ipc,
"ipc_no_reshard": self._update_ipc_no_reshard,
"normal": self.load_model,
}
if handler := strategy_handlers.get(self.load_config.load_strategy):
handler()
else:
raise ValueError(
f"Unsupported strategy: {self.load_config.load_strategy}")
logger.info(
f"Update parameters in {time.perf_counter()-start_time:.2f}s")
self._finalize_update(pid)
def _update_ipc_snapshot(self):
"""Update using IPC snapshot strategy for elastic recovery."""
model_path = os.path.join(
self.parallel_config.model_name_or_path,
f"model_state.tp0{self.meta_src_id}.pdparams")
try:
ipc_state_dict = paddle.load(model_path)
except FileNotFoundError:
fallback_path = f"/shared_ipc_meta/model_state.tp0{self.meta_src_id}.pdparams"
ipc_state_dict = paddle.load(fallback_path)
try:
self._update_model_from_state(ipc_state_dict, "snapshot")
except Exception:
self.models[0].set_state_dict(ipc_state_dict)
logger.warning(
"load model from no_reshard weight, maybe need more GPU memory"
)
logger.info("IPC snapshot update parameters completed")
def _update_ipc(self):
"""Update using standard IPC strategy (requires Training Worker)."""
ipc_meta = paddle.load(self.ipc_path)
state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
self._update_model_from_state(state_dict, "raw")
logger.info("IPC update parameters completed")
def _update_ipc_no_reshard(self):
"""Update using no-reshard IPC strategy (faster but uses more memory)."""
ipc_meta = paddle.load(self.ipc_path)
state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
self.models[0].set_state_dict(state_dict)
logger.info("IPC no-reshard update parameters completed")
def load_model(self) -> nn.Layer:
"""Standard model loading without IPC."""
architectures = self.fd_config.model_config.architectures[0]
model_class = MODEL_CLASSES[architectures]
state_dict = load_composite_checkpoint(
self.fd_config.parallel_config.model_name_or_path,
model_class,
self.fd_config.model_config,
return_numpy=True)
self.models[0].set_state_dict(state_dict)
logger.info("normal load update parameters completed")
def clear_parameters(self, pid: int = 0) -> None:
"""Clear all model parameters and free memory."""
        logger.info("start clearing parameters")
paddle.device.cuda.empty_cache()
for model in self.models:
for param in model.state_dict().values():
param._clear_data()
self._verify_parameters("clearance")
if self.nranks > 1:
paddle.distributed.barrier()
paddle.distributed.shutdown_process_group()
self._update_shared_status(pid, -2)
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor],
src_type: str):
"""Update model parameters from given state dictionary."""
update_count = 0
for name, new_param in state_dict.items():
if name not in self.state_dict:
logger.debug(f"Ignoring unmatched {src_type} param: {name}")
continue
target_param = self.state_dict[name]
self._validate_parameter_match(name, new_param, target_param)
new_param._share_buffer_to(target_param)
update_count += 1
logger.info(
f"🆗 Updated {update_count}/{len(state_dict)} parameters from {src_type} source"
)
def _validate_parameter_match(self, name: str, src: paddle.Tensor,
dst: paddle.Tensor):
"""验证参数一致性"""
if src.dtype != dst.dtype:
raise TypeError(
f"Type mismatch for {name}: {src.dtype} vs {dst.dtype}")
if src.shape != dst.shape:
raise ValueError(
f"Shape mismatch for {name}: {src.shape} vs {dst.shape}")
def _finalize_update(self, pid: int):
"""Finalize update process with verification."""
self._verify_parameters("update")
if self.nranks > 1:
paddle.distributed.barrier()
if not self.first_load:
self._update_shared_status(pid, 0)
self.first_load = False
def _get_gpu_id(self) -> int:
"""Get current GPU device ID."""
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0").split(",")
return int(visible_devices[int(os.getenv("FLAGS_selected_gpus", "0"))])
def _verify_parameters(self, operation: str):
"""Verify parameters are in expected state after operation."""
expected_initialized = (operation == "update")
all_valid = True
for name, param in self.state_dict.items():
is_initialized = param._is_initialized()
if is_initialized != expected_initialized:
logger.error(
f"Verification failed after {operation}: "
f"Param {name} initialized={is_initialized} (expected {expected_initialized})"
)
all_valid = False
if all_valid:
logger.info(f"💡 Model Parameter {operation} verified successfully")
else:
raise RuntimeError(
f"❌ Model Parameter {operation} verification failed")
@staticmethod
def _convert_ipc_meta_to_tensor(
ipc_meta: Dict[str, Any]) -> Dict[str, paddle.Tensor]:
"""Convert IPC metadata to tensor dictionary."""
converted = {}
for name, meta in ipc_meta.items():
meta[0] = meta[0].encode("latin-1")
meta[6] = int(os.getenv("FLAGS_selected_gpus", "0"))
tensor = paddle.base.core.LoDTensor._new_shared_cuda(tuple(meta))
converted[name] = paddle.to_tensor(tensor)
return converted
def _log_memory(self, context: str):
"""Log current GPU memory usage."""
max_alloc = paddle.device.cuda.max_memory_allocated() / (1024**3)
max_reserved = paddle.device.cuda.max_memory_reserved() / (1024**3)
curr_alloc = paddle.device.cuda.memory_allocated() / (1024**3)
curr_reserved = paddle.device.cuda.memory_reserved() / (1024**3)
logger.warning(f"GPU memory usage {context}:"
f"max_allocated: {max_alloc:.2f}GB\n"
f"max_reserved: {max_reserved:.2f}GB\n"
f"current_allocated: {curr_alloc:.2f}GB\n"
f"current_reserved: {curr_reserved:.2f}GB")
def _update_shared_status(self, pid: int, status: int) -> None:
"""Update shared memory status flag for inter-process communication."""
array = np.zeros([1], dtype=np.int32)
shm = SharedMemory(create=False,
size=array.nbytes,
name=f"model_weights_status.{pid}")
value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
if self.rank == 0:
value[self.rank] = status
@staticmethod
def check_model_weights_status(model_weights_status, model_runner, pid):
"""
check model weights status
"""
is_stop = 0
while model_weights_status.value[0] != 0:
if model_weights_status.value[0] == 1:
logger.info(
"infer engine stopped! start to load new checkpoint...")
model_runner.update_parameters(pid)
elif model_weights_status.value[0] == -1:
logger.info(
"infer engine stopped! start to clear checkpoint...")
model_runner.clear_parameters(pid)
while True:
if model_weights_status.value[0] == 0:
logger.info("finished loading new checkpoint")
break
elif is_stop == 1 or (model_weights_status.value[0] == -2
and is_stop == 0):
if is_stop == 0:
logger.info("finished clearing checkpoint")
is_stop = 1
time.sleep(0.001)
break
else:
time.sleep(0.001)
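The loop above is driven by a one-int status flag in shared memory (1 = reload, -1 = clear, 0 = serving again, -2 = cleared). Below is a minimal sketch of how a controller process could flip that flag and wait for the worker to finish; the shared-memory name and status codes come from the code above, while the controller itself (function names, polling interval) is purely illustrative.

import time
import numpy as np
from multiprocessing.shared_memory import SharedMemory

def request_reload(pid: int) -> None:
    # setting the flag to 1 makes check_model_weights_status() call update_parameters(pid)
    shm = SharedMemory(create=False, name=f"model_weights_status.{pid}")
    flag = np.ndarray([1], dtype=np.int32, buffer=shm.buf)
    flag[0] = 1
    shm.close()

def wait_until_serving(pid: int, poll_s: float = 0.001) -> None:
    # the worker resets the flag to 0 in _finalize_update() once the new weights are live
    shm = SharedMemory(create=False, name=f"model_weights_status.{pid}")
    flag = np.ndarray([1], dtype=np.int32, buffer=shm.buf)
    while flag[0] != 0:
        time.sleep(poll_s)
    shm.close()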

View File

@@ -0,0 +1,215 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Dict
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.model_loader import ModelRegistry
from fastdeploy.model_executor.models.ernie4_5_moe import \
Ernie4_5_MoeForCausalLM
from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel
RL_MODEL_CLASSES = {
"Ernie4_5_MoeForCausalLMRL": Ernie4_5_MoeForCausalLM,
"Qwen2ForCausalLMRL": Qwen2PretrainedModel,
"Qwen3ForCausalLMRL": Qwen3PretrainedModel,
"Qwen3MoeForCausalLMRL": Qwen3MoePretrainedModel,
}
class RollOutModel(nn.Layer):
"""Main model class for rollout operations, supports multimodal components for train."""
def __init__(self, fd_config: FDConfig):
"""Initialize with FastDeploy configuration."""
super(RollOutModel, self).__init__()
self.fd_config = fd_config
self._init_models()
def _init_models(self):
"""Initialize all model components including multimodal if needed."""
self.is_vl = "VL" in self.fd_config.model_config.architectures[0]
self.rollout_model = self._load_primary_model()
self.rollout_models = [self.rollout_model]
if self.is_vl:
self._init_multimodal_models()
self.rollout_models.extend(
[self.vision_model, self.resampler_model])
def _init_multimodal_models(self):
"""Initialize vision and resampler components for multimodal models."""
# TODO:(gaoziyuan) Implement actual initialization
self.vision_model = nn.Layer()
self.resampler_model = nn.Layer()
def _load_primary_model(self):
"""Load main model from loader based on config."""
if "VL" in self.fd_config.model_config.architectures[0]:
logger.error("Loaded Vision Language model, not support now")
context = paddle.LazyGuard()
architectures = f"{self.fd_config.model_config.architectures[0]}RL"
with context:
model_cls = ModelRegistry.get_class(architectures)
model = model_cls(self.fd_config)
model.eval()
return model
def get_name_mappings_to_training(self) -> Dict[str, str]:
"""Get parameter name mappings between rollout and training models."""
mappings = {}
for model in self.rollout_models:
mappings.update(
getattr(model, "get_name_mappings_to_training", lambda: {})())
return mappings
@paddle.no_grad()
def state_dict(self):
"""state_dict"""
all_params = {}
for model in self.rollout_models:
for name, param in model.state_dict().items():
logger.debug(
f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
)
all_params[name] = param
return all_params
class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
"""
Ernie4_5_MoeForCausalLMRL
"""
def __init__(self, fd_config: FDConfig):
"""
Args:
fd_config (FDConfig): Configurations for the LLM model.
"""
super(Ernie4_5_MoeForCausalLMRL, self).__init__(fd_config)
@classmethod
    def name(cls):
        """Model name used for registry lookup."""
return "Ernie4_5_MoeForCausalLMRL"
def get_name_mappings_to_training(self):
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
have_bias = self.fd_config.model_config.get("have_norm_bias", False)
# Prepare placeholders
place_holders = ["weight"] + (["bias"] if have_bias else [])
# Initialize mapping dictionary
infer_to_train = {}
# Static mappings (non-layer specific)
static_mappings = {
"model.embeddings.word_embeddings.weight":
"ernie.embed_tokens.weight",
"model.norm.ln_weight": "ernie.norm.weight",
"lm_head.out_linear.weight": "lm_head.weight"
}
if self.fd_config.model_config.get("weight_sharing", False):
# Support tie_word_embeddings
logger.debug("enable tie_word_embeddings")
static_mappings.pop("lm_head.out_linear.weight")
infer_to_train.update(static_mappings)
infer_base_name = "model.hidden_layers"
# Helper function to add layer mappings
def _add_layer_mappings(layer_idx, is_moe_layer=False):
            # Input layernorm mappings
for ph in place_holders:
infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
train_key = f"ernie.layers.{layer_idx}.input_layernorm.{ph}"
infer_to_train[infer_key] = train_key
# Common attention mappings
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
f"ernie.layers.{layer_idx}.self_attn.qkv_proj.{ph}"
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
f"ernie.layers.{layer_idx}.self_attn.o_proj.{ph}"
# Post-attention layernorm
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
f"ernie.layers.{layer_idx}.post_attention_layernorm.{ph}"
if not is_moe_layer:
# Dense FFN mappings
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
f"ernie.layers.{layer_idx}.mlp.up_gate_proj.{ph}"
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
f"ernie.layers.{layer_idx}.mlp.down_proj.{ph}"
else:
# MoE specific mappings
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
f"ernie.layers.{layer_idx}.mlp.gate.weight"
if self.fd_config.moe_config.moe_use_aux_free:
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
f"ernie.layers.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
# Support shared experts
if self.fd_config.model_config.get(
"moe_num_shared_experts") > 0:
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \
f"ernie.layers.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \
f"ernie.layers.{layer_idx}.mlp.shared_experts.down_proj.weight"
# MoE experts mappings
for expert_idx in range(self.fd_config.moe_config.num_experts):
for ph in place_holders:
# FFN1 (up_gate_proj)
ffn1_key = f"{infer_base_name}.{layer_idx}.mlp.fused_moe.moe_ffn1_weight"
if ffn1_key not in infer_to_train:
infer_to_train[ffn1_key] = []
infer_to_train[ffn1_key].append(
f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
)
# FFN2 (down_proj)
ffn2_key = f"{infer_base_name}.{layer_idx}.mlp.fused_moe.moe_ffn2_weight"
if ffn2_key not in infer_to_train:
infer_to_train[ffn2_key] = []
infer_to_train[ffn2_key].append(
f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
)
# Process non-MoE layers
for layer_idx in range(
self.fd_config.moe_config.moe_layer_start_index):
_add_layer_mappings(layer_idx, is_moe_layer=False)
# Process MoE layers
for layer_idx in range(self.fd_config.moe_config.moe_layer_start_index,
self.fd_config.model_config.num_layers):
_add_layer_mappings(layer_idx, is_moe_layer=True)
return infer_to_train
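get_name_mappings_to_training() returns a dict whose values are either a single training parameter name or, for the fused MoE weights, a list of per-expert names. Here is a hedged sketch of how a training-side exporter might consume it; stacking the expert weights along a new leading axis is an assumption, not something this file specifies.

import paddle

def build_infer_state_dict(train_state: dict, infer_to_train: dict) -> dict:
    infer_state = {}
    for infer_name, train_name in infer_to_train.items():
        if isinstance(train_name, list):
            # fused MoE ffn1/ffn2: gather every expert's tensor and stack them
            infer_state[infer_name] = paddle.stack(
                [train_state[name] for name in train_name], axis=0)
        else:
            infer_state[infer_name] = train_state[train_name]
    return infer_state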

View File

@@ -260,12 +260,13 @@ class ResultReader(object):
ResultReader use an async thread to continue get infer result from redis ResultReader use an async thread to continue get infer result from redis
""" """
def __init__(self, client, idx, batch=200, ttl=900): def __init__(self, client, idx, batch=200, ttl=900, group=""):
self.idx = idx self.idx = idx
self.batch = batch self.batch = batch
self.client = client self.client = client
self.data = deque() self.data = deque()
self.ttl = ttl self.ttl = ttl
self.group = group
self.reqs = dict() self.reqs = dict()
self.out_buffer = dict() self.out_buffer = dict()
@@ -380,15 +381,18 @@ class ResultReader(object):
fetch infer results from redis for the give keys fetch infer results from redis for the give keys
""" """
total = 0 total = 0
if self.group != "":
keys = [self.group]
for key in keys: for key in keys:
#logger.info(f"Sync Results from Redis {key}")
results = self.client.rpop(key, self.batch) results = self.client.rpop(key, self.batch)
if results is None or len(results) == 0: if results is None or len(results) == 0:
continue continue
#logger.info(f"Rpop {self.idx}: {len(results)}") #logger.info(f"Rpop {key} {self.idx}: {len(results)}")
total += len(results) total += len(results)
for result in results: for result in results:
try: try:
#logger.info(f"Scheduler Get Results: {result}") # logger.info(f"Scheduler Get Results: {result.request_id}")
data = orjson.loads(result) data = orjson.loads(result)
result = RequestOutput.from_dict(data) result = RequestOutput.from_dict(data)
self.data.appendleft(result) self.data.appendleft(result)
@@ -425,8 +429,9 @@ class APIScheduler(object):
start backup threads start backup threads
""" """
for i in range(self.reader_parallel): for i in range(self.reader_parallel):
group = f"{self.nodeid}-{i}"
reader = ResultReader(self.client, i, self.reader_batch_size, reader = ResultReader(self.client, i, self.reader_batch_size,
self.ttl) self.ttl, group)
self.readers.append(reader) self.readers.append(reader)
self.clear_expired_nodes_thread = threading.Thread( self.clear_expired_nodes_thread = threading.Thread(
@@ -481,15 +486,16 @@ class APIScheduler(object):
reader = self.readers[reader_idx] reader = self.readers[reader_idx]
reader.add_req(req) reader.add_req(req)
group = self.readers[reader_idx].group
reader_idx = (reader_idx + 1) % len(self.readers) reader_idx = (reader_idx + 1) % len(self.readers)
self.schedule(req, pnodes, dnodes, mnodes) self.schedule(req, pnodes, dnodes, mnodes, group)
except IndexError: except IndexError:
continue continue
except Exception as e: except Exception as e:
logger.error(f"APIScheduler Schedule req error: {str(e)}") logger.error(f"APIScheduler Schedule req error: {str(e)}")
def schedule(self, req, pnodes, dnodes, mnodes): def schedule(self, req, pnodes, dnodes, mnodes, group=""):
""" """
schedule an req to according redis node queue schedule an req to according redis node queue
""" """
@@ -498,7 +504,9 @@ class APIScheduler(object):
pnode = self.select_pd(req, pnodes, "prefill") pnode = self.select_pd(req, pnodes, "prefill")
if pnode.role == "mixed": if pnode.role == "mixed":
req.disaggregate_info = None req.disaggregate_info = None
req_str = orjson.dumps(req.to_dict()) req_dict = req.to_dict()
req_dict["group"] = group
req_str = orjson.dumps(req_dict)
pkey = f"ReqQ_{pnode.nodeid}" pkey = f"ReqQ_{pnode.nodeid}"
#logger.info(f"Schedule Req {req_str} to Mixed") #logger.info(f"Schedule Req {req_str} to Mixed")
self.client.lpush(pkey, req_str) self.client.lpush(pkey, req_str)
@@ -518,7 +526,9 @@ class APIScheduler(object):
disaggregated["transfer_protocol"] = transfer_protocol[0] disaggregated["transfer_protocol"] = transfer_protocol[0]
req.disaggregate_info = disaggregated req.disaggregate_info = disaggregated
pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}" pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
req_str = orjson.dumps(req.to_dict()) req_dict = req.to_dict()
req_dict["group"] = group
req_str = orjson.dumps(req_dict)
#logger.info(f"Schedule Req {req_str}") #logger.info(f"Schedule Req {req_str}")
self.client.lpush(dkey, req_str) self.client.lpush(dkey, req_str)
self.client.lpush(pkey, req_str) self.client.lpush(pkey, req_str)
@@ -634,7 +644,9 @@ class ResultWriter(object):
size = len(self.data) size = len(self.data)
if size == 0: if size == 0:
self.cond.wait() self.cond.wait()
#qsize = size
size = min(size, self.batch) size = min(size, self.batch)
#logger.info(f"Writer {self.idx} Queue Size: {qsize}, Cur Size: {size}")
groups = dict() groups = dict()
for i in range(size): for i in range(size):
key, item = self.data.pop() key, item = self.data.pop()
@@ -749,12 +761,13 @@ class InferScheduler(object):
for req_str in reqs: for req_str in reqs:
req = orjson.loads(req_str) req = orjson.loads(req_str)
group = req.get("group", "")
req = Request.from_dict(req) req = Request.from_dict(req)
writer_idx = select_writer(req) writer_idx = select_writer(req)
logger.info( logger.info(
f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}" f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}"
) )
req.request_id = f"{req.request_id}#{writer_idx}" req.request_id = f"{req.request_id}#{writer_idx}#{group}"
if self.role == "prefill" or self.role == "mixed": if self.role == "prefill" or self.role == "mixed":
self.reqs_queue.append(req) self.reqs_queue.append(req)
self.node.add_req(req.request_id, self.node.add_req(req.request_id,
@@ -813,10 +826,10 @@ class InferScheduler(object):
req_ids.add(result.request_id) req_ids.add(result.request_id)
req_id, idx = result.request_id.split("#") req_id, idx, group = result.request_id.split("#")
result.request_id = req_id result.request_id = req_id
key = (req_id, int(idx)) key = (req_id if group == "" else group, int(idx))
if key not in groups: if key not in groups:
groups[key] = list() groups[key] = list()
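The hunks above thread a reader group through the scheduler: the API side tags each request id with the writer index and group, and the infer side splits the tag back apart to decide which key the result lands on. A small self-contained sketch of that round trip (function names are illustrative):

def encode_request_id(request_id: str, writer_idx: int, group: str = "") -> str:
    return f"{request_id}#{writer_idx}#{group}"

def decode_request_id(tagged_id: str):
    req_id, idx, group = tagged_id.split("#")
    # results are grouped under the reader group when one is set, else under the request id
    group_key = (req_id if group == "" else group, int(idx))
    return req_id, group_key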

View File

@@ -1,15 +0,0 @@
export FLAGS_use_pd_disaggregation=1
export INFERENCE_MSG_QUEUE_ID=1
export FD_LOG_DIR="log_decode"
CUDA_VISIBLE_DEVICES=4,5,6,7 python fastdeploy.entrypoints.openai.api_server.py --config test.yaml --port 9812 --max-num-seqs 256 --kv-cache-ratio 0.8 --splitwise-role "decode" --engine-worker-queue-port 6678 --innode-prefill-ports 6677 --cache-queue-port 55667 --enable-prefix-caching --enable-chunked-prefill &
export FD_LOG_DIR="log_prefill"
export INFERENCE_MSG_QUEUE_ID=3
export FLAGS_fmt_write_cache_completed_signal=1
export PREFILL_NODE_ONE_STEP_STOP=1
CUDA_VISIBLE_DEVICES=0,1,2,3 python fastdeploy.entrypoints.openai.api_server.py --config test.yaml --port 9811 --cpu-offload-gb 5 --max-num-seqs 16 --kv-cache-ratio 0.9 --splitwise-role "prefill" --engine-worker-queue-port 6677 --enable-prefix-caching --cache-queue-port 55663 &

View File

@@ -1,4 +1,4 @@
model: "baidu/ERNIE-45-300B-A47B-Paddle" model: "baidu/paddle_internal/ERNIE-45-Turbo"
max_model_len: 32768 max_model_len: 32768
max_num_seqs: 128 max_num_seqs: 128
kv_cache_ratio: 0.5 kv_cache_ratio: 0.5

View File

@@ -84,9 +84,10 @@ def replicate_experts(
return phy2log, rank, logcnt return phy2log, rank, logcnt
def rebalance_experts_hierarchical(weight: np.ndarray, def rebalance_experts_hierarchical(
num_physical_experts: int, num_groups: int, weight: np.ndarray, num_physical_experts: int, num_groups: int,
num_nodes: int, num_gpus: int): num_nodes: int,
num_gpus: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
""" """
Parameters: Parameters:
weight: [num_moe_layers, num_logical_experts] weight: [num_moe_layers, num_logical_experts]

View File

@@ -1,7 +1,7 @@
""" """
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
@@ -13,9 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
""" """redundant expert manger."""
redundant expert manger from typing import Optional, Tuple
"""
import numpy as np import numpy as np
import paddle import paddle
@@ -29,9 +28,9 @@ class RedundantExpertManger:
RedundantExpertManger RedundantExpertManger
""" """
def __init__(self, n_routed_experts, num_hidden_layers, def __init__(self, n_routed_experts: int, num_hidden_layers: int,
redundant_experts_num, ep_size): redundant_experts_num: int, ep_size: int) -> None:
"""Initialize a redundant expert manager"""
self.num_expert = n_routed_experts self.num_expert = n_routed_experts
self.redundant_experts_num = redundant_experts_num self.redundant_experts_num = redundant_experts_num
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
@@ -94,7 +93,9 @@ class RedundantExpertManger:
num_replicas {self.num_replicas} export_per_rank {self.export_per_rank}" num_replicas {self.num_replicas} export_per_rank {self.export_per_rank}"
) )
def get_ep_rank_to_expert_id_list_by_layer(self, layer_id): def get_ep_rank_to_expert_id_list_by_layer(
self, layer_id: int
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" """
get_ep_rank_to_expert_id_list_by_layer get_ep_rank_to_expert_id_list_by_layer
""" """
@@ -103,7 +104,9 @@ class RedundantExpertManger:
self.model_expert_in_rank_num_list[layer_id], \ self.model_expert_in_rank_num_list[layer_id], \
self.model_tokens_per_expert_stats_list[layer_id] self.model_tokens_per_expert_stats_list[layer_id]
def get_ep_rank_to_expert_id_list(self, layer_id): def get_ep_rank_to_expert_id_list(
self, layer_id: int
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" """
get_ep_rank_to_expert_id_list get_ep_rank_to_expert_id_list
""" """
@@ -112,9 +115,12 @@ class RedundantExpertManger:
self.model_expert_in_rank_num_list[layer_id], \ self.model_expert_in_rank_num_list[layer_id], \
self.model_tokens_per_expert_stats_list[layer_id] self.model_tokens_per_expert_stats_list[layer_id]
def get_expert_tokens_stats(self, def get_expert_tokens_stats(
self,
verbose: bool = False, verbose: bool = False,
clear_stat: bool = False): clear_stat: bool = False
) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray],
Optional[np.ndarray]]:
""" """
get_per_expert_tokens_stats get_per_expert_tokens_stats
""" """
@@ -130,7 +136,7 @@ class RedundantExpertManger:
if clear_stat: if clear_stat:
self.model_tokens_per_expert_stats_list.zero_() self.model_tokens_per_expert_stats_list.zero_()
def get_expert_id_to_ep_rank_array(self): def get_expert_id_to_ep_rank_array(self) -> np.ndarray:
""" """
get_expert_id_to_ep_rank_array get_expert_id_to_ep_rank_array
""" """
@@ -140,7 +146,7 @@ class RedundantExpertManger:
rank_expert_list: np.ndarray, rank_expert_list: np.ndarray,
logical_to_physical_map: np.ndarray, logical_to_physical_map: np.ndarray,
expert_count: np.ndarray, expert_count: np.ndarray,
clear_stat: bool = True): clear_stat: bool = True) -> None:
""" """
update_expert_rank_table update_expert_rank_table
""" """

View File

@@ -330,6 +330,8 @@ class ForwardMeta():
decoder_batch_ids: Optional[paddle.Tensor] = None decoder_batch_ids: Optional[paddle.Tensor] = None
# for attention backend # for attention backend
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# is_decode_batch or not
is_decode_batch: bool = False
@classmethod @classmethod
def init_forward_meta(cls, share_inputs: Dict, def init_forward_meta(cls, share_inputs: Dict,
@@ -357,6 +359,11 @@ class ForwardMeta():
) )
return ret return ret
def clear_caches(self):
"""safe clear caches"""
if self.caches:
del self.caches
@dataclass @dataclass
class XPUForwardMeta(ForwardMeta): class XPUForwardMeta(ForwardMeta):
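For reference, a standalone sketch of the two ForwardMeta additions in this hunk, the is_decode_batch flag the runner now sets before every forward pass and the clear_caches() helper, with the rest of the dataclass stripped down to just the fields needed to show the behaviour:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ForwardMetaSketch:
    caches: Optional[List] = None       # KV cache tensors shared with share_inputs
    step_use_cudagraph: bool = False
    is_decode_batch: bool = False       # True when no request in the batch is prefilling

    def clear_caches(self):
        # safely drop the cache references so the allocator can reclaim the memory
        if self.caches:
            del self.caches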

View File

@@ -20,6 +20,7 @@ from typing import List, Optional
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request from fastdeploy.engine.request import Request
@@ -41,13 +42,10 @@ from fastdeploy.model_executor.pre_and_post_process import (post_process,
rebuild_padding, rebuild_padding,
step_cuda) step_cuda)
from fastdeploy.spec_decode import MTPProposer, NgramProposer from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy.utils import get_logger
from fastdeploy.worker.forward_meta import ForwardMeta from fastdeploy.worker.forward_meta import ForwardMeta
from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
logger = get_logger("gpu_model_runner", "gpu_model_runner.log")
class GPUModelRunner(ModelRunnerBase): class GPUModelRunner(ModelRunnerBase):
""" """ """ """
@@ -593,6 +591,10 @@ class GPUModelRunner(ModelRunnerBase):
time_before_load = time.perf_counter() time_before_load = time.perf_counter()
# 1. Load original model # 1. Load original model
self.model = get_model_from_loader(fd_config=self.fd_config) self.model = get_model_from_loader(fd_config=self.fd_config)
# 1.1 Load RL dynamic model
if self.fd_config.load_config.dynamic_load_weight:
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)
# 2. Load lora model # 2. Load lora model
@@ -621,6 +623,25 @@ class GPUModelRunner(ModelRunnerBase):
for attn_backend in self.attn_backends: for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta) attn_backend.init_attention_metadata(self.forward_meta)
def clear_cache(self):
"""Clear cached data from shared inputs and forward metadata."""
self.share_inputs.pop("caches", None)
if self.forward_meta is not None:
self.forward_meta.clear_caches()
def clear_parameters(self, pid):
""""dynamic model loader use to clear parameters use for RL"""
self.dynamic_weight_manager.clear_parameters(pid)
self.clear_cache()
paddle.device.cuda.empty_cache()
self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")
def update_parameters(self, pid):
""""dynamic model loader use to update parameters use for RL"""
self.dynamic_weight_manager.update_parameters(pid)
self.initialize_kv_cache()
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
def initialize_kv_cache(self) -> None: def initialize_kv_cache(self) -> None:
""" """
Initialize kv cache Initialize kv cache
@@ -691,15 +712,14 @@ class GPUModelRunner(ModelRunnerBase):
head_dim = self.model_config.head_dim head_dim = self.model_config.head_dim
# Get the attention backend # Get the attention backend
attn_cls = get_attention_backend( attn_cls = get_attention_backend()
self.parallel_config.attention_backend)
attn_backend = attn_cls(self.fd_config, attn_backend = attn_cls(self.fd_config,
kv_num_heads=self.model_config.kv_num_heads, kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads, num_heads=num_heads,
head_dim=head_dim) head_dim=head_dim)
if attn_backend is None: if attn_backend is None:
raise NotImplementedError( raise NotImplementedError(
f"{ self.parallel_config.attention_backend} attention backend is not support by GPUModelRunner" "Attention backend which you chose is not support by GPUModelRunner"
) )
self.attn_backends.append(attn_backend) self.attn_backends.append(attn_backend)
@@ -735,6 +755,7 @@ class GPUModelRunner(ModelRunnerBase):
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0) > 1).sum() > 0)
self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
self.forward_meta.is_decode_batch = is_decode_batch
model_output = self.model( model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"], ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta) forward_meta=self.forward_meta)
@@ -967,6 +988,7 @@ class GPUModelRunner(ModelRunnerBase):
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0) > 1).sum() > 0)
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
self.forward_meta.is_decode_batch = is_decode_batch
model_output = self.model( model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"], ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta) forward_meta=self.forward_meta)
@@ -1124,9 +1146,7 @@ class GPUModelRunner(ModelRunnerBase):
batch_size=min(self.parallel_config.max_num_seqs, 3)) batch_size=min(self.parallel_config.max_num_seqs, 3))
# 3. gc # 3. gc
del self.share_inputs["caches"] self.clear_cache()
if self.forward_meta is not None:
del self.forward_meta.caches
if self.speculative_method in ["mtp"]: if self.speculative_method in ["mtp"]:
self.proposer.clear_dummy_input() self.proposer.clear_dummy_input()
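Putting the new runner hooks together, the RL reload cycle enabled here is: release weights and KV cache, wait for the training side to publish fresh weights over IPC, then re-attach weights and rebuild the cache. A sketch under the assumption that something external (e.g. the status-flag loop in the dynamic weight manager) sequences the two calls; `runner` stands in for a GPUModelRunner instance.

def rl_reload_cycle(runner, pid: int) -> None:
    # 1. drop old weights, clear the KV cache and return the memory to the allocator
    runner.clear_parameters(pid)
    # ... training side publishes fresh weights via IPC here ...
    # 2. attach the new weights and re-create the KV cache for the next rollout
    runner.update_parameters(pid)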

View File

@@ -16,10 +16,12 @@
import json import json
import os import os
import random import random
import argparse
import numpy as np import numpy as np
import paddle import paddle
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
from paddleformers.transformers.model_utils import load_tp_checkpoint
from safetensors import safe_open from safetensors import safe_open
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
@@ -38,11 +40,13 @@ from fastdeploy.model_executor.models.ernie4_5_vl.dfnrope.modeling import \
DFNRopeVisionTransformerPretrainedModel DFNRopeVisionTransformerPretrainedModel
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ( from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import (
ScatterOp, VariableResolutionResamplerModel) ScatterOp, VariableResolutionResamplerModel)
from fastdeploy.model_executor.models.utils import load_checkpoint
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
from fastdeploy.worker.forward_meta import ForwardMeta from fastdeploy.worker.forward_meta import ForwardMeta
from fastdeploy.worker.utils import check_safetensors_model from fastdeploy.worker.utils import check_safetensors_model
from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase
from fastdeploy.config import (DeviceConfig, FDConfig, KVCacheConfig,
LoadConfig, ModelConfig, MoEConfig,
MoEPhase, ParallelConfig, SpeculativeConfig)
if current_platform.is_cuda() and current_platform.available(): if current_platform.is_cuda() and current_platform.available():
from fastdeploy.model_executor.layers.utils import ( from fastdeploy.model_executor.layers.utils import (
@@ -55,8 +59,20 @@ from fastdeploy.model_executor.ops.gpu import (save_output,
class GPUVLModelRunner(VLModelRunnerBase): class GPUVLModelRunner(VLModelRunnerBase):
"""
The GPUVLModelRunner class for vision-language tasks on GPU.
"""
def __init__(self, config, args, nranks, rank): def __init__(
self,
config: ModelConfig,
args: argparse.Namespace,
nranks: int,
rank: int,
) -> None:
"""
GPUVLModelRunner init
"""
self.nranks = nranks self.nranks = nranks
self.rank = rank self.rank = rank
@@ -104,14 +120,11 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.sampler = Sampler() self.sampler = Sampler()
def _reset_paddle_env(self): def _reset_paddle_env(self):
#FLAGS_gqa_use_tensorcore
#FLAGS_ffn2_use_hardamard
# gqa .etc paddle Flags set
pass pass
def update_chunked_prefill(self, tasks): def update_chunked_prefill(self, tasks: list[any]) -> None:
""" """
更新chunked prefill相关参数 update chunked prefill
""" """
if not self.args.enable_chunked_prefill: if not self.args.enable_chunked_prefill:
return return
@@ -135,7 +148,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
"image_features"] = self.extract_vision_features( "image_features"] = self.extract_vision_features(
inputs) inputs)
else: else:
# 兼容没有图片和视频的情况 # Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None self.share_inputs["image_features"] = None
token_chunk_size = inputs["input_ids"].shape[1] token_chunk_size = inputs["input_ids"].shape[1]
@@ -152,7 +165,14 @@ class GPUVLModelRunner(VLModelRunnerBase):
task.start_idx += token_chunk_size task.start_idx += token_chunk_size
task.chunk_idx += 1 task.chunk_idx += 1
def _load_model(self, model_name, dynamic_load_weight): def _load_model(
self,
model_name: str,
dynamic_load_weight: int = 0,
) -> None:
"""
Load the model from the given model name.
"""
vocab_file_names = [ vocab_file_names = [
"tokenizer.model", "spm.model", "ernie_token_100k.model" "tokenizer.model", "spm.model", "ernie_token_100k.model"
@@ -261,7 +281,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
fd_config.parallel_config.max_model_len = fd_config.model_config.max_seq_len fd_config.parallel_config.max_model_len = fd_config.model_config.max_seq_len
self.fd_config = fd_config self.fd_config = fd_config
attn_backend_cls = get_attention_backend(self.args.attention_backend) attn_backend_cls = get_attention_backend()
num_heads = self.fd_config.model_config.num_attention_heads // \ num_heads = self.fd_config.model_config.num_attention_heads // \
self.fd_config.parallel_config.tensor_parallel_degree self.fd_config.parallel_config.tensor_parallel_degree
self.fd_config.model_config.kv_num_heads = int( self.fd_config.model_config.kv_num_heads = int(
@@ -275,7 +295,10 @@ class GPUVLModelRunner(VLModelRunnerBase):
head_dim=head_dim) head_dim=head_dim)
self._init_kvcache() self._init_kvcache()
def init_extra_input(self, config, args): def init_extra_input(self, config: ModelConfig, args: argparse.Namespace) -> None:
"""
Initialize extra input tensors.
"""
head_dim = self.model_cfg.head_dim head_dim = self.model_cfg.head_dim
self.share_inputs.update({ self.share_inputs.update({
"rope_emb": "rope_emb":
@@ -287,29 +310,31 @@ class GPUVLModelRunner(VLModelRunnerBase):
}) })
self.share_inputs.update({"image_features": None}) self.share_inputs.update({"image_features": None})
self.share_inputs.update({ self.share_inputs.update({
"need_think_end": paddle.full(shape=[ "need_think_end":
args.max_num_seqs, 1], paddle.full(shape=[args.max_num_seqs, 1],
fill_value=0, fill_value=0,
dtype="int32") dtype="int32")
}) })
self.share_inputs.update({ self.share_inputs.update({
"enable_thinking": paddle.full(shape=[1], "enable_thinking":
fill_value=True, paddle.full(shape=[1], fill_value=True, dtype="bool")
dtype="bool")
}) })
self.share_inputs.update({ self.share_inputs.update({
"reasoning_index": paddle.full(shape=[ "reasoning_index":
args.max_num_seqs, 1], paddle.full(shape=[args.max_num_seqs, 1],
fill_value=0, fill_value=0,
dtype="int32") dtype="int32")
}) })
def init_rotary_position_embedding(self, max_model_len): def init_rotary_position_embedding(self, max_model_len: int) -> None:
"""
Init rotary position embedding
"""
pass pass
def _init_kvcache(self): def _init_kvcache(self):
""" """
分享不拷贝数据 Init kv cache
""" """
cache_kvs = {} cache_kvs = {}
total_block_num = self.num_gpu_blocks total_block_num = self.num_gpu_blocks
@@ -352,7 +377,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
del value del value
paddle.device.cuda.empty_cache() paddle.device.cuda.empty_cache()
def clear_parameters(self, pid): def clear_parameters(self, pid: int) -> None:
""" clear_parameters """ """ clear_parameters """
if "caches" in self.share_inputs: if "caches" in self.share_inputs:
self.model.clear_parameters(pid) self.model.clear_parameters(pid)
@@ -360,7 +385,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
paddle.device.cuda.empty_cache() paddle.device.cuda.empty_cache()
self.model.log_memory_usage("clear all memory") self.model.log_memory_usage("clear all memory")
def update_parameters(self, pid): def update_parameters(self, pid: int) -> None:
""" update_parameters """ """ update_parameters """
if "caches" not in self.share_inputs: if "caches" not in self.share_inputs:
self.model.update_parameters(pid) self.model.update_parameters(pid)
@@ -368,7 +393,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.model.log_memory_usage("update all memory") self.model.log_memory_usage("update all memory")
@paddle.no_grad() @paddle.no_grad()
def set_state_dict(self, args): def set_state_dict(self, args: argparse.Namespace) -> None:
"""set_state_dict""" """set_state_dict"""
if not self.is_safetensors_model: if not self.is_safetensors_model:
rank_model_paths = [] rank_model_paths = []
@@ -401,7 +426,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.model.set_state_dict(state_dict) self.model.set_state_dict(state_dict)
self.resampler_model.set_state_dict(resampler_state) self.resampler_model.set_state_dict(resampler_state)
else: else:
state_dict = load_checkpoint( state_dict = load_tp_checkpoint(
args.model_name_or_path, args.model_name_or_path,
Ernie4_5_PretrainedModel, Ernie4_5_PretrainedModel,
self.model_cfg, self.model_cfg,
@@ -414,10 +439,14 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.model.set_state_dict(state_dict) self.model.set_state_dict(state_dict)
@paddle.no_grad() @paddle.no_grad()
def vit_load(self, model_path, tensor_parallel_degree, def vit_load(
tensor_parallel_rank): self,
model_path: str,
tensor_parallel_degree: int,
tensor_parallel_rank: int,
) -> None:
""" """
vit_load tp参数 Load vit tp weight
""" """
if tensor_parallel_degree == 1: if tensor_parallel_degree == 1:
rank_model_path = os.path.join(model_path, "model_state.pdparams") rank_model_path = os.path.join(model_path, "model_state.pdparams")
@@ -430,15 +459,18 @@ class GPUVLModelRunner(VLModelRunnerBase):
raise ValueError(f"No such a file {rank_model_path}") raise ValueError(f"No such a file {rank_model_path}")
@paddle.no_grad() @paddle.no_grad()
def inject_pp_vision_model(self, args, cfg): def inject_pp_vision_model(self, args: argparse.Namespace, cfg: Ernie4_5_VLMoeConfig):
""" """
注入vision model参数 Inject pp vision model
""" """
def set_vision_state_dict(model, def set_vision_state_dict(model,
tensor_parallel_degree=8, tensor_parallel_degree: int=8,
tensor_parallel_rank=0, tensor_parallel_rank: int=0,
name=""): name: str=""):
"""
Set vision model weight
"""
model_state_dict = model.state_dict() model_state_dict = model.state_dict()
compat_keys = [name + k for k in model_state_dict.keys()] compat_keys = [name + k for k in model_state_dict.keys()]
model_files = set() model_files = set()
@@ -543,7 +575,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
return vision_model, resampler_model return vision_model, resampler_model
@paddle.no_grad() @paddle.no_grad()
def extract_vision_features(self, inputs): def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
"""extract_vision_features""" """extract_vision_features"""
assert inputs["images"] is not None assert inputs["images"] is not None
grid_thw = inputs["grid_thw"] grid_thw = inputs["grid_thw"]
@@ -585,7 +617,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
return image_features return image_features
@paddle.no_grad() @paddle.no_grad()
def prepare_rope3d(self, position_ids, **kwargs): def prepare_rope3d(self, position_ids: paddle.Tensor, **kwargs) -> paddle.Tensor:
"""prepare_rope3d""" """prepare_rope3d"""
prefix_max_position_ids = paddle.max(position_ids) + 1 prefix_max_position_ids = paddle.max(position_ids) + 1
@@ -608,13 +640,13 @@ class GPUVLModelRunner(VLModelRunnerBase):
def prefill_finished(self): def prefill_finished(self):
""" """
判断是否已经完成了prefill操作 Verify prefill operation completion
""" """
prefill_statue = (self.share_inputs["seq_lens_this_time"] != 0) & ( prefill_statue = (self.share_inputs["seq_lens_this_time"] != 0) & (
self.share_inputs["seq_lens_this_time"] != 1) self.share_inputs["seq_lens_this_time"] != 1)
return not paddle.any(prefill_statue).numpy() return not paddle.any(prefill_statue).numpy()
def dy_input_preprocess(self, tasks): def dy_input_preprocess(self, tasks: list[any]) -> None:
""" """
dynamic insertion dynamic insertion
""" """
@@ -662,7 +694,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
"image_features"] = self.extract_vision_features( "image_features"] = self.extract_vision_features(
inputs) inputs)
else: else:
# 兼容没有图片和视频的情况 # Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None self.share_inputs["image_features"] = None
if task.multimodal_inputs["position_ids"] is not None: if task.multimodal_inputs["position_ids"] is not None:
position_ids = paddle.to_tensor( position_ids = paddle.to_tensor(
@@ -688,7 +720,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
"image_features"] = self.extract_vision_features( "image_features"] = self.extract_vision_features(
inputs) inputs)
else: else:
# 兼容没有图片和视频的情况 # Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None self.share_inputs["image_features"] = None
position_ids = inputs["position_ids"] position_ids = inputs["position_ids"]
@@ -702,10 +734,11 @@ class GPUVLModelRunner(VLModelRunnerBase):
# force </think> # force </think>
self.share_inputs["enable_thinking"][:] = kwargs["enable_thinking"] self.share_inputs["enable_thinking"][:] = kwargs["enable_thinking"]
self.share_inputs["need_think_end"][idx:idx + self.share_inputs["need_think_end"][
1, :] = 1 if kwargs["enable_thinking"] else 0 idx:idx + 1, :] = 1 if kwargs["enable_thinking"] else 0
self.share_inputs["reasoning_index"][idx:idx + 1, :] = kwargs["reasoning_max_tokens"] self.share_inputs["reasoning_index"][
idx:idx + 1, :] = kwargs["reasoning_max_tokens"]
self.share_inputs["rope_emb"][idx:idx + self.share_inputs["rope_emb"][idx:idx +
1, :] = self.prepare_rope3d( 1, :] = self.prepare_rope3d(
@@ -737,7 +770,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
idx:idx + 1, :encoder_block_num] = np.array(task.block_tables, idx:idx + 1, :encoder_block_num] = np.array(task.block_tables,
dtype="int32") dtype="int32")
def pre_process(self): def pre_process(self) -> None:
""" """
pre_process pre_process
""" """
@@ -794,7 +827,10 @@ class GPUVLModelRunner(VLModelRunnerBase):
eos_token_ids=self.share_inputs["eos_token_id"], eos_token_ids=self.share_inputs["eos_token_id"],
) )
def generate(self): def generate(self) -> None:
"""
generate
"""
self.pre_process() self.pre_process()
hiddden_states = self.model(self.share_inputs["ids_remove_padding"], hiddden_states = self.model(self.share_inputs["ids_remove_padding"],
self.share_inputs["image_features"], self.share_inputs["image_features"],
@@ -815,7 +851,10 @@ class GPUVLModelRunner(VLModelRunnerBase):
paddle.distributed.broadcast(next_tokens, 0) paddle.distributed.broadcast(next_tokens, 0)
self.post_process(next_tokens) self.post_process(next_tokens)
def post_process(self, next_tokens): def post_process(self, next_tokens: paddle.Tensor) -> None:
"""
post_process
"""
if self.share_inputs["enable_thinking"]: if self.share_inputs["enable_thinking"]:
exists_think_end = next_tokens == self.model_cfg.think_end_id exists_think_end = next_tokens == self.model_cfg.think_end_id
paddle.assign( paddle.assign(
@@ -823,37 +862,28 @@ class GPUVLModelRunner(VLModelRunnerBase):
exists_think_end, exists_think_end,
self.share_inputs["need_think_end"] - 1, self.share_inputs["need_think_end"] - 1,
self.share_inputs["need_think_end"], self.share_inputs["need_think_end"],
), ), self.share_inputs["need_think_end"])
self.share_inputs["need_think_end"]
)
paddle.assign( paddle.assign(
paddle.where( paddle.where(
self.share_inputs["need_think_end"].cast("bool"), self.share_inputs["need_think_end"].cast("bool"),
self.share_inputs["reasoning_index"] - 1, self.share_inputs["reasoning_index"] - 1,
self.share_inputs["reasoning_index"], self.share_inputs["reasoning_index"],
), ), self.share_inputs["reasoning_index"])
self.share_inputs["reasoning_index"]
)
stop_wo_think = ( stop_wo_think = (
( (next_tokens == self.share_inputs["eos_token_id"]) |
next_tokens == self.share_inputs["eos_token_id"] (self.share_inputs["reasoning_index"] == 0)) & (
) | ( self.share_inputs["need_think_end"] > 0)
self.share_inputs["reasoning_index"] == 0 next_tokens = paddle.where(stop_wo_think,
) self.model_cfg.think_end_id,
) & ( next_tokens)
self.share_inputs["need_think_end"] > 0
)
next_tokens = paddle.where(stop_wo_think, self.model_cfg.think_end_id, next_tokens)
paddle.assign( paddle.assign(
paddle.where( paddle.where(
stop_wo_think, stop_wo_think,
self.share_inputs["need_think_end"] - 1, self.share_inputs["need_think_end"] - 1,
self.share_inputs["need_think_end"], self.share_inputs["need_think_end"],
), ), self.share_inputs["need_think_end"])
self.share_inputs["need_think_end"]
)
paddle.assign( paddle.assign(
paddle.where( paddle.where(
self.share_inputs["stop_flags"], self.share_inputs["stop_flags"],
@@ -899,14 +929,13 @@ class GPUVLModelRunner(VLModelRunnerBase):
def _cal_theortical_kvcache(self): def _cal_theortical_kvcache(self):
""" """
计算理论的kvcache大小 Calculate the theoretical KV cache size
""" """
num_layers = self.model_cfg.get("num_layers", num_layers = self.model_cfg.get("num_layers",
None) or self.model_cfg.get( None) or self.model_cfg.get(
"num_hidden_layers", None) "num_hidden_layers", None)
byte_of_cache = 2 byte_of_cache = 2
#TODO # support c8 c4
# 支持c8 c4
hidden_dim = self.model_cfg.head_dim * self.model_cfg.kv_num_head hidden_dim = self.model_cfg.head_dim * self.model_cfg.kv_num_head
theoretical_kv_cache_memory = (2 * byte_of_cache * theoretical_kv_cache_memory = (2 * byte_of_cache *
@@ -915,6 +944,9 @@ class GPUVLModelRunner(VLModelRunnerBase):
return theoretical_kv_cache_memory return theoretical_kv_cache_memory
def _update_share_input_block_num(self): def _update_share_input_block_num(self):
"""
Update share_inputs['block_tables'] and share_inputs['free_list']
"""
num_gpu_blocks = self.num_gpu_blocks num_gpu_blocks = self.num_gpu_blocks
del self.share_inputs["caches"] del self.share_inputs["caches"]
@@ -924,7 +956,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.share_inputs["block_tables"] = paddle.full( self.share_inputs["block_tables"] = paddle.full(
[self.args.max_num_seqs, num_gpu_blocks], -1, dtype="int32") [self.args.max_num_seqs, num_gpu_blocks], -1, dtype="int32")
# 初始化free list # Init free list
free_list = list( free_list = list(
range(num_gpu_blocks - 1, range(num_gpu_blocks - 1,
int(num_gpu_blocks * self.args.kv_cache_ratio) - 1, -1)) int(num_gpu_blocks * self.args.kv_cache_ratio) - 1, -1))
@@ -936,7 +968,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
paddle.full([1], self.free_list_len, dtype="int32"), paddle.full([1], self.free_list_len, dtype="int32"),
}) })
def dummy_input(self, num_total_tokens, number_of_tasks): def dummy_input(self, num_total_tokens: int, number_of_tasks: int) -> None:
""" """
fake input to profile fake input to profile
""" """
@@ -974,7 +1006,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(idx * block_num, \ self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(idx * block_num, \
(idx + 1) * block_num, 1) (idx + 1) * block_num, 1)
def _preprocess_task(self, one): def _preprocess_task(self, one: dict) -> None:
"""process batch""" """process batch"""
input_ids = one["input_ids"][np.newaxis, :] input_ids = one["input_ids"][np.newaxis, :]
@@ -1012,13 +1044,13 @@ class GPUVLModelRunner(VLModelRunnerBase):
def build_stream_line_model( def build_stream_line_model(
model_path, model_path: str,
dtype, dtype: str,
block_size, block_size: int,
max_model_len, max_model_len: int,
tokenizer, tokenizer: ErnieBotTokenizer,
quantization: str = "None", quantization: str = "None",
): ) -> tuple[FDConfig, paddle.nn.layer]:
""" """
build model build model
""" """
@@ -1028,9 +1060,6 @@ def build_stream_line_model(
from paddleformers.trl import llm_utils from paddleformers.trl import llm_utils
from paddleformers.utils.log import logger from paddleformers.utils.log import logger
from fastdeploy.config import (DeviceConfig, FDConfig, KVCacheConfig,
LoadConfig, ModelConfig, MoEConfig,
MoEPhase, ParallelConfig, SpeculativeConfig)
from fastdeploy.model_executor.layers.quantization import \ from fastdeploy.model_executor.layers.quantization import \
get_quantization_config get_quantization_config
from fastdeploy.model_executor.models.model_base import ModelRegistry from fastdeploy.model_executor.models.model_base import ModelRegistry
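The reshaped post_process() logic above forces a </think> token when a request is still inside its thinking span but either emits EOS or runs out of reasoning budget. A numpy-only sketch of that condition, with names mirroring the share_inputs entries; it is illustrative and not the code path itself.

import numpy as np

def force_think_end(next_tokens, eos_token_id, think_end_id,
                    need_think_end, reasoning_index):
    # stop "without thinking" = (EOS or reasoning budget exhausted) while still thinking
    stop_wo_think = ((next_tokens == eos_token_id) | (reasoning_index == 0)) \
                    & (need_think_end > 0)
    next_tokens = np.where(stop_wo_think, think_end_id, next_tokens)
    need_think_end = np.where(stop_wo_think, need_think_end - 1, need_think_end)
    return next_tokens, need_think_end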

View File

@@ -15,10 +15,12 @@
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import argparse
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
from fastdeploy.config import ModelConfig
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger
@@ -27,20 +29,20 @@ logger = get_logger("worker", "worker.log")
class VLModelRunnerBase(ABC): class VLModelRunnerBase(ABC):
""" """
Initializes the model and sets up necessary parameters. Engine -> (WIP)Executor -> Worker -> VLModelRunnerBase -> Model
VLModelRunnerBase interface abstracts the model execution logic that
Args: contains input preparation, token generation, and token processing.
config (Config): The configuration object for the model. """
args (Namespace): The arguments passed to the script.
def __init__(
Returns: self,
None. config: ModelConfig,
args: argparse.Namespace,
Raises: ) -> None:
None. """
VLModelRunnerBase init
""" """
def __init__(self, config, args):
self.share_inputs = {} self.share_inputs = {}
self.model_cfg = config self.model_cfg = config
self.args = args self.args = args
@@ -66,7 +68,7 @@ class VLModelRunnerBase(ABC):
f"current_allocated: {curr_alloc:.2f}GB\n" f"current_allocated: {curr_alloc:.2f}GB\n"
f"current_reserved: {curr_reserved:.2f}GB") f"current_reserved: {curr_reserved:.2f}GB")
def init_dist_env(self, seed=20): def init_dist_env(self, seed=20) -> None:
""" """
init distributed env init distributed env
""" """
@@ -85,7 +87,7 @@ class VLModelRunnerBase(ABC):
fleet.init(is_collective=True, strategy=strategy) fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index() self.rank = fleet.worker_index()
def _load_model_init_val(self): def _load_model_init_val(self) -> None:
""" """
initialize model config from config file initialize model config from config file
""" """
@@ -105,18 +107,10 @@ class VLModelRunnerBase(ABC):
self.min_length = _get_attr("min_length", 1) self.min_length = _get_attr("min_length", 1)
self.max_length = self.args.max_model_len self.max_length = self.args.max_model_len
def _init_share_inputs(self, max_num_seqs): def _init_share_inputs(self, max_num_seqs: int) -> None:
""" """
初始化共享的输入,包括预测和训练。 initialize shared inputs
将所有需要的张量都初始化为零或者特定值。
Args:
max_num_seqs (int): 最大批次大小,用于初始化张量。
Returns:
None.
""" """
# 统一使用paddle.full创建张量
self._load_model_init_val() self._load_model_init_val()
int64_config = {"dtype": "int64"} int64_config = {"dtype": "int64"}
@@ -124,7 +118,6 @@ class VLModelRunnerBase(ABC):
float32_config = {"dtype": "float32"} float32_config = {"dtype": "float32"}
bool_config = {"dtype": "bool"} bool_config = {"dtype": "bool"}
# 批量初始化张量
self.share_inputs.update({ self.share_inputs.update({
"pre_ids": "pre_ids":
paddle.full([max_num_seqs, self.max_length], -1, **int64_config), paddle.full([max_num_seqs, self.max_length], -1, **int64_config),
@@ -146,7 +139,6 @@ class VLModelRunnerBase(ABC):
"presence_score": "presence_score":
paddle.full([max_num_seqs, 1], self.presence_score, paddle.full([max_num_seqs, 1], self.presence_score,
**float32_config), **float32_config),
# TODO 名称统一
"min_dec_len": "min_dec_len":
paddle.full([max_num_seqs, 1], self.min_length, **int64_config), paddle.full([max_num_seqs, 1], self.min_length, **int64_config),
"max_dec_len": "max_dec_len":
@@ -207,14 +199,12 @@ class VLModelRunnerBase(ABC):
paddle.full([max_num_seqs, 1], -1, **int32_config), paddle.full([max_num_seqs, 1], -1, **int32_config),
}) })
# 计算block tables相关参数
pre_max_block_num = ( pre_max_block_num = (
self.args.max_model_len + self.args.block_size - self.args.max_model_len + self.args.block_size -
1) // self.args.block_size + self.args.enc_dec_block_num 1) // self.args.block_size + self.args.enc_dec_block_num
self.share_inputs["block_tables"] = paddle.full( self.share_inputs["block_tables"] = paddle.full(
[max_num_seqs, pre_max_block_num], -1, **int32_config) [max_num_seqs, pre_max_block_num], -1, **int32_config)
# 初始化free list
free_list = list( free_list = list(
range( range(
self.args.total_block_num - 1, self.args.total_block_num - 1,
@@ -228,7 +218,6 @@ class VLModelRunnerBase(ABC):
paddle.full([1], self.free_list_len, **int32_config), paddle.full([1], self.free_list_len, **int32_config),
}) })
# 初始化stop seqs
self.share_inputs.update({ self.share_inputs.update({
"stop_seqs_len": "stop_seqs_len":
paddle.full([self.model_cfg.max_stop_seqs_num], 0, **int32_config), paddle.full([self.model_cfg.max_stop_seqs_num], 0, **int32_config),
@@ -239,9 +228,9 @@ class VLModelRunnerBase(ABC):
], -1, **int64_config), ], -1, **int64_config),
}) })
def update_chunked_prefill(self, tasks): def update_chunked_prefill(self, tasks: list[any]) -> None:
""" """
更新chunked prefill相关参数 update chunked prefill
""" """
if not self.args.enable_chunked_prefill: if not self.args.enable_chunked_prefill:
return return
@@ -251,58 +240,38 @@ class VLModelRunnerBase(ABC):
def prefill_finished(self): def prefill_finished(self):
""" """
判断是否已经完成了prefill操作 Verify prefill operation completion
""" """
return True return True
@abstractmethod @abstractmethod
def init_rotary_position_embedding(self, max_model_len): def init_rotary_position_embedding(self, max_model_len: int) -> None:
""" """
初始化旋转位置编码,需要重写该方法。 Init rotary position embedding
参数max_model_lenint序列的最大长度。
返回值None无返回值需要在方法内完成初始化操作。
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def _load_model(self, model_dir, dynamic_load_weight): def _load_model(
self,
model_name: str,
dynamic_load_weight: int = 0,
) -> None:
""" """
加载模型,包括模型参数和优化器等。 Load the model from the given model name.
需要子类实现该方法。
Args:
model_dir (str): 模型保存的目录路径。
Raises:
NotImplementedError: 当前方法未被实现。
Returns:
None.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def _init_kvcache(self): def _init_kvcache(self):
""" """
初始化kv缓存用于快速查找数据块。 Init kv cache
该方法需要被子类实现。
Args:
max_block_num (int): 最大的数据块数量。
Raises:
NotImplementedError: 当该方法未被子类实现时会引发此异常。
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def dy_input_preprocess(self): def dy_input_preprocess(self, tasks: list[any]) -> None:
""" """
预处理输入数据用于计算dy。 dynamic insertion
该函数需要在每次forward之前调用并且只能调用一次。
默认实现抛出NotImplementedError。子类可以根据具体的模型实现此功能。
Raises:
NotImplementedError: 如果没有实现该方法。
""" """
raise NotImplementedError raise NotImplementedError
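With the interface now typed, a concrete runner has to provide at least the abstract methods visible in this hunk. The following is a skeleton only; the real base class may declare further abstract members, and the stub bodies are placeholders.

from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase

class DummyVLModelRunner(VLModelRunnerBase):
    def init_rotary_position_embedding(self, max_model_len: int) -> None:
        pass  # rope tables could also be built lazily at first use

    def _load_model(self, model_name: str, dynamic_load_weight: int = 0) -> None:
        pass  # load the weights for `model_name` into the runner here

    def _init_kvcache(self):
        pass  # allocate per-layer KV cache tensors into self.share_inputs["caches"]

    def dy_input_preprocess(self, tasks: list[any]) -> None:
        pass  # copy newly inserted tasks into self.share_inputs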

View File

@@ -26,6 +26,7 @@ import paddle.distributed.fleet as fleet
from fastdeploy.engine.config import ModelConfig from fastdeploy.engine.config import ModelConfig
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.utils import get_logger, none_or_str from fastdeploy.utils import get_logger, none_or_str
from fastdeploy.worker.worker_process import initialize_fd_config, parse_args
logger = get_logger("worker", "worker.log") logger = get_logger("worker", "worker.log")
@@ -35,7 +36,14 @@ class PrefillTracker:
Record the prefill time of the request Record the prefill time of the request
""" """
def __init__(self, engine_pid): def __init__(
self,
engine_pid: int,
) -> None:
"""
Initialize the PrefillTracker.
"""
super().__init__()
self.start_times = defaultdict(float) self.start_times = defaultdict(float)
prefill_time_data = np.zeros([100], dtype=np.float32) prefill_time_data = np.zeros([100], dtype=np.float32)
self.prefill_time_signal = IPCSignal(name="prefill_time_signal", self.prefill_time_signal = IPCSignal(name="prefill_time_signal",
@@ -46,7 +54,7 @@ class PrefillTracker:
self.current_index = 0 self.current_index = 0
self.executor = ThreadPoolExecutor(max_workers=1) self.executor = ThreadPoolExecutor(max_workers=1)
def start_prefill(self, task_idx): def start_prefill(self, task_idx: int):
""" """
Record the start time of the prefill process for a given task index. Record the start time of the prefill process for a given task index.
@@ -55,7 +63,7 @@ class PrefillTracker:
""" """
self.start_times[task_idx] = time.time() self.start_times[task_idx] = time.time()
def end_prefill(self, task_idx): def end_prefill(self, task_idx: int):
""" """
Record the end time of the prefill process for a given task index and Record the end time of the prefill process for a given task index and
asynchronously submit the duration for metric recording. asynchronously submit the duration for metric recording.
@@ -69,7 +77,7 @@ class PrefillTracker:
self.executor.submit(self._record_metrics, duration) self.executor.submit(self._record_metrics, duration)
del self.start_times[task_idx] del self.start_times[task_idx]
def _record_metrics(self, duration): def _record_metrics(self, duration: float):
""" """
Internal method to record the prefill duration into the signal buffer. Internal method to record the prefill duration into the signal buffer.
Logs the duration and updates a circular buffer of timing metrics. Logs the duration and updates a circular buffer of timing metrics.
@@ -89,19 +97,19 @@ class PrefillTracker:
class Worker: class Worker:
def __init__(self, args):
""" """
Args: Engine -> (WIP)Executor -> Worker -> ModelRunner -> Model
args (ArgumentParser): Command-line arguments, including the model name, port number, etc. Worker interface that allows the inference framework to cleanly separate implementations for different hardware.
Returns:
None. No return value; after initialization the relevant parameters and objects are stored in class attributes.
Raises:
None. No exceptions are raised.
""" """
def __init__(
self,
args,
) -> None:
"""
Initialize the Worker.
"""
super().__init__()
self.args = args self.args = args
self.MAX_INFER_SEED = 9223372036854775806 self.MAX_INFER_SEED = 9223372036854775806
paddle.set_default_dtype(args.dtype) paddle.set_default_dtype(args.dtype)
@@ -123,7 +131,7 @@ class Worker:
rank=self.rank) rank=self.rank)
self.prefill_tracker = PrefillTracker(args.engine_pid) self.prefill_tracker = PrefillTracker(args.engine_pid)
# TODO: multi-node deployment # Only applicable for standalone (single-machine) inference
address = ('0.0.0.0', self.args.engine_worker_queue_port) address = ('0.0.0.0', self.args.engine_worker_queue_port)
self.engine_worker_queue = EngineWorkerQueue( self.engine_worker_queue = EngineWorkerQueue(
address=address, address=address,
@@ -154,7 +162,10 @@ class Worker:
self.rank = fleet.worker_index() self.rank = fleet.worker_index()
def init_health(self): def init_health(self):
# worker_ready_signal: used by the engine to detect whether each worker process is ready """
init health signals
"""
# To perceive whether each worker process is ready
worker_ready_signal_data = np.zeros(shape=[self.nranks], worker_ready_signal_data = np.zeros(shape=[self.nranks],
dtype=np.int32) dtype=np.int32)
self.worker_ready_signal = IPCSignal(name="worker_ready_signal", self.worker_ready_signal = IPCSignal(name="worker_ready_signal",
@@ -164,7 +175,7 @@ class Worker:
create=False) create=False)
self.worker_ready_signal.value[self.rank] = 1 self.worker_ready_signal.value[self.rank] = 1
# worker_live_signal: used by the engine to detect whether each worker process is alive; records the time of every step # To monitor the liveness of worker processes and record each step's timestamp
worker_healthy_live_recorded_time_array = np.zeros(shape=[self.nranks], worker_healthy_live_recorded_time_array = np.zeros(shape=[self.nranks],
dtype=np.int32) dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal( self.worker_healthy_live_signal = IPCSignal(
@@ -175,7 +186,7 @@ class Worker:
create=False) create=False)
self.worker_healthy_live_signal.value[self.rank] = int(time.time()) self.worker_healthy_live_signal.value[self.rank] = int(time.time())
# exist_task_signal: used by each worker process to detect whether there is a new task to handle # To perceive whether there is a new task to be processed
exist_task_signal_data = np.zeros([1], dtype=np.int32) exist_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(name="exist_task_signal", self.exist_task_signal = IPCSignal(name="exist_task_signal",
array=exist_task_signal_data, array=exist_task_signal_data,
@@ -183,7 +194,7 @@ class Worker:
suffix=self.args.engine_pid, suffix=self.args.engine_pid,
create=False) create=False)
# exist_swapped_task_signal: used by the engine to detect whether the worker holds swapped tasks # To detect whether there are swapped tasks in the worker
exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32) exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_swapped_task_signal = IPCSignal( self.exist_swapped_task_signal = IPCSignal(
name="exist_swapped_task_signal", name="exist_swapped_task_signal",
@@ -192,7 +203,6 @@ class Worker:
suffix=self.args.engine_pid, suffix=self.args.engine_pid,
create=False) create=False)
# model_weights_status: used by the engine to track the model-weight status in each worker
model_weights_status = np.zeros([1], dtype=np.int32) model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal( self.model_weights_status_signal = IPCSignal(
name="model_weights_status", name="model_weights_status",
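The comments above list the IPC signals a worker publishes. As a hedged sketch, this is how a monitoring process could read the liveness signal back; only the IPCSignal call shape is taken from the code above, while the helper and the 30-second staleness threshold are assumptions.

import time

import numpy as np

from fastdeploy.inter_communicator import IPCSignal

def workers_look_alive(nranks: int, engine_pid: int, stale_after_s: int = 30) -> bool:
    buf = np.zeros(shape=[nranks], dtype=np.int32)
    live_signal = IPCSignal(name="worker_healthy_live_signal",
                            array=buf,
                            dtype=np.int32,
                            suffix=engine_pid,
                            create=False)
    now = int(time.time())
    # Each worker writes int(time.time()) every step; a stale entry suggests a hang.
    return all(now - int(ts) < stale_after_s for ts in live_signal.value)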
@@ -309,17 +319,7 @@ class Worker:
def run(self): def run(self):
""" """
Run loop: continuously fetch tasks from the queue and run inference. run function, continuously get tasks and do inference.
When the queue is empty or all nodes are waiting, sleep for a while and then try to fetch tasks again.
Args:
None.
Returns:
None.
Raises:
None.
""" """
infer_seed_increment = paddle.full(shape=[self.args.max_num_seqs, 1], infer_seed_increment = paddle.full(shape=[self.args.max_num_seqs, 1],
fill_value=4, fill_value=4,
@@ -526,153 +526,6 @@ class Worker:
break break
def parse_args():
"""
parse args from command line
"""
parser = argparse.ArgumentParser("FastDeploy LLM Inference")
parser.add_argument("-m",
"--model_name_or_path",
type=str,
default="./output",
help="model dir")
parser.add_argument("-mbs",
"--max_num_seqs",
type=int,
default=34,
help="max batch size")
parser.add_argument("--total_block_num", type=int, default=2000)
parser.add_argument("--block_size", type=int, default=64)
parser.add_argument("--engine_worker_queue_port", type=int, default=9923)
parser.add_argument("--max_model_len",
type=int,
default=3072,
help="max model len")
parser.add_argument("--device_ids",
type=str,
default="0",
help="cuda visible devices")
parser.add_argument("--dtype",
type=str,
default="bfloat16",
help="input dtype")
parser.add_argument("--enc_dec_block_num",
type=int,
default=1,
help="encoder's decoder num")
parser.add_argument("--kv_cache_ratio",
type=float,
default=0.7,
help="kv cache ratio for input")
parser.add_argument("--first_token_id",
type=int,
default=1,
help="first token id")
parser.add_argument("--gpu_memory_utilization",
type=float,
default=0.9,
help="gpu memory utilization")
parser.add_argument("--engine_pid",
type=int,
default=None,
help="Process ID of engine")
parser.add_argument("--do_profile",
action='store_true',
help="do profile or not")
parser.add_argument("--dynamic_load_weight",
action='store_true',
help="dynamic load weight or not")
parser.add_argument("--pad_token_id",
type=int,
default=-1,
help="pad token id")
parser.add_argument("--eos_tokens_lens",
type=int,
default=2,
help="eos token lens")
parser.add_argument("--enable_chunked_prefill",
action='store_true',
help="enable chunked prefill")
parser.add_argument(
"--speculative_method",
default=None,
type=none_or_str,
choices=[None, "ngram", "mtp"],
)
parser.add_argument(
"--speculative_max_draft_token_num",
default=1,
type=int,
)
parser.add_argument(
"--speculative_model_name_or_path",
default="",
type=str,
)
parser.add_argument(
"--speculative_model_quantization",
default="",
type=str,
)
parser.add_argument(
"--attention_backend",
default="APPEND_ATTN",
type=str,
choices=[
"APPEND_ATTN",
],
)
parser.add_argument("--max_num_batched_tokens",
type=int,
default=2048,
help="max num batched tokens")
parser.add_argument("--enable_prefix_caching",
action='store_true',
help="enable prefix cache")
parser.add_argument("--splitwise_role",
type=str,
default="mixed",
help="splitwise role")
parser.add_argument("--ori_vocab_size", type=int, default=None)
parser.add_argument("--tensor_parallel_size",
type=int,
default=1,
help="tensor parallel size")
parser.add_argument("--expert_parallel_size",
type=int,
default=1,
help="expert parallel size")
parser.add_argument("--quantization",
type=str,
default="",
help="Quantization name for the model, currently supports " \
"'wint4', 'wint8'," \
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")
parser.add_argument("--enable_static_graph_inference",
action='store_true',
help="Whether to use static mode; if enabled, " \
"'paddle.to_static' will be used to convert dynamic to static.")
parser.add_argument("--use_cudagraph",
action='store_true',
help="Flags to enable cuda graph.")
parser.add_argument("--max_capture_batch_size",
type=int,
default=64,
help="Maximum of Batch Size for Warm Up.")
parser.add_argument("--guided_decoding_backend",
type=str,
default="off",
help="guided decoding backend")
parser.add_argument("--disable_any_whitespace",
action='store_false',
help="Disable any whitespace for guided decoding.")
args = parser.parse_args()
return args
def main(): def main():
""" """
start worker start worker

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" """
import argparse import argparse
import json
import time import time
from typing import List from typing import List
@@ -23,6 +22,7 @@ import paddle
import paddle.distributed as dist import paddle.distributed as dist
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
from fastdeploy import envs
from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig, from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
GraphOptimizationConfig, LoadConfig, GraphOptimizationConfig, LoadConfig,
ModelConfig, MoEConfig, MoEPhase, ModelConfig, MoEConfig, MoEPhase,
@@ -61,14 +61,21 @@ class PaddleDisWorkerProc():
def __init__( def __init__(
self, self,
fd_config: FDConfig, fd_config: FDConfig,
): ) -> None:
"""
Initialize a distributed worker and task queue for single-node multi-GPU setup.
Args:
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
"""
self.fd_config = fd_config self.fd_config = fd_config
self.parallel_config = fd_config.parallel_config self.parallel_config = fd_config.parallel_config
# Initialize distributed environment # Initialize distributed environment
(self.rank, self.local_rank) = self.init_distributed_enviroment() (self.ranks, self.local_rank) = self.init_distributed_enviroment()
assert self.parallel_config.tensor_parallel_degree * self.parallel_config.expert_parallel_degree == self.rank assert self.parallel_config.tensor_parallel_degree * self.parallel_config.expert_parallel_degree == self.ranks
self.fd_config.parallel_config.tensor_parallel_rank = \ self.fd_config.parallel_config.tensor_parallel_rank = \
self.local_rank % self.parallel_config.tensor_parallel_degree self.local_rank % self.parallel_config.tensor_parallel_degree
@@ -81,8 +88,6 @@ class PaddleDisWorkerProc():
self.fd_config.moe_config.num_experts_start_offset = \ self.fd_config.moe_config.num_experts_start_offset = \
self.fd_config.parallel_config.expert_parallel_rank * self.fd_config.moe_config.num_experts_per_rank self.fd_config.parallel_config.expert_parallel_rank * self.fd_config.moe_config.num_experts_per_rank
self.fd_config.parallel_config.column_cut = False
# For auto TP split # For auto TP split
self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_degree self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_degree
self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
@@ -95,7 +100,7 @@ class PaddleDisWorkerProc():
# TODO(gongshaotian): Use worker factory to get worker # TODO(gongshaotian): Use worker factory to get worker
self.worker = get_worker(fd_config=fd_config, self.worker = get_worker(fd_config=fd_config,
local_rank=self.local_rank, local_rank=self.local_rank,
rank=self.rank) rank=self.ranks)
# Initialize task queue # Initialize task queue
task_address = ('0.0.0.0', task_address = ('0.0.0.0',
@@ -109,7 +114,7 @@ class PaddleDisWorkerProc():
local_data_parallel_id=self.fd_config.parallel_config. local_data_parallel_id=self.fd_config.parallel_config.
expert_parallel_rank) expert_parallel_rank)
def init_health_status(self): def init_health_status(self) -> None:
""" """
Initialize the health status of the worker. Initialize the health status of the worker.
Worker Status: Worker Status:
@@ -134,7 +139,7 @@ class PaddleDisWorkerProc():
self.worker_ready_signal.value[self.local_rank % 8] = 1 self.worker_ready_signal.value[self.local_rank % 8] = 1
# init worker_healthy_live_signal # init worker_healthy_live_signal
workers_alive = np.zeros(shape=[self.rank], dtype=np.int32) workers_alive = np.zeros(shape=[self.ranks], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal( self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal", name="worker_healthy_live_signal",
array=workers_alive, array=workers_alive,
@@ -183,16 +188,7 @@ class PaddleDisWorkerProc():
suffix=self.parallel_config.engine_pid, suffix=self.parallel_config.engine_pid,
create=False) create=False)
# init model_weights_status def event_loop_ep(self) -> None:
workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
self.model_weights_status = IPCSignal(
name="model_weights_status",
array=workers_model_weights,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
def event_loop_ep(self):
""" """
Temporary loop function for EP until DP is supported Temporary loop function for EP until DP is supported
""" """
@@ -217,7 +213,7 @@ class PaddleDisWorkerProc():
# These generated tokens can be obtained through get_output op. # These generated tokens can be obtained through get_output op.
self.worker.execute_model() self.worker.execute_model()
def event_loop_normal(self): def event_loop_normal(self) -> None:
""" Main event loop for Paddle Distributed Workers. """ Main event loop for Paddle Distributed Workers.
TODO(gongshaotian): support remote calling of functions that control worker. TODO(gongshaotian): support remote calling of functions that control worker.
""" """
@@ -225,6 +221,12 @@ class PaddleDisWorkerProc():
self.nnode = 1 self.nnode = 1
req_ids = [] req_ids = []
while True: while True:
if self.local_rank == 0:
if self.model_weights_status.value[0] != 0:
self.exist_task_signal.value[0] = 2
else:
self.exist_task_signal.value[0] = 0
if self.parallel_config.tensor_parallel_degree > 1: if self.parallel_config.tensor_parallel_degree > 1:
# Synchronize before updating weights # Synchronize before updating weights
paddle.distributed.barrier() paddle.distributed.barrier()
@@ -234,7 +236,7 @@ class PaddleDisWorkerProc():
time.time()) time.time())
# The first worker detects whether there are tasks in the task queue # The first worker detects whether there are tasks in the task queue
mp_num_per_node = self.rank / self.nnode mp_num_per_node = self.ranks / self.nnode
if self.local_rank % mp_num_per_node == 0: if self.local_rank % mp_num_per_node == 0:
if self.task_queue.num_tasks() > 0: if self.task_queue.num_tasks() > 0:
if self.nnode > 1: if self.nnode > 1:
@@ -249,6 +251,14 @@ class PaddleDisWorkerProc():
# TODO(@wufeisheng): Split TP group and EP group # TODO(@wufeisheng): Split TP group and EP group
paddle.distributed.barrier() paddle.distributed.barrier()
if self.fd_config.load_config.dynamic_load_weight:
if self.exist_task_signal.value[0] == 2:
from fastdeploy.rl.dynamic_weight_manager import \
DynamicWeightManager
DynamicWeightManager.check_model_weights_status(
self.model_weights_status, self.worker.model_runner,
self.parallel_config.engine_pid)
if self.exist_task_signal.value[ if self.exist_task_signal.value[
self.fd_config.parallel_config.expert_parallel_rank] == 1 or \ self.fd_config.parallel_config.expert_parallel_rank] == 1 or \
self.task_queue.read_finish_flag.get() == 1: self.task_queue.read_finish_flag.get() == 1:
@@ -275,7 +285,7 @@ class PaddleDisWorkerProc():
self.worker.preprocess_new_task(req_dicts) self.worker.preprocess_new_task(req_dicts)
if not self.worker.model_runner.not_need_stop(): if not self.worker.model_runner.not_need_stop():
if self.rank > 1: if self.ranks > 1:
paddle.distributed.barrier() paddle.distributed.barrier()
time.sleep(0.001) time.sleep(0.001)
@@ -288,15 +298,15 @@ class PaddleDisWorkerProc():
self.exist_prefill_task_signal.value[ self.exist_prefill_task_signal.value[
0] = self.worker.prefill_finished() 0] = self.worker.prefill_finished()
def init_distributed_enviroment(self, seed=20) -> List[int]: def init_distributed_enviroment(self, seed: int = 20) -> List[int]:
""" Initialize Paddle Fleet and get rank of worker """ """ Initialize Paddle Fleet and get rank of worker """
# Global rank # Global rank
self.rank = dist.get_world_size() self.ranks = dist.get_world_size()
dist_strategy = fleet.DistributedStrategy() dist_strategy = fleet.DistributedStrategy()
dist_strategy.hybrid_configs = { dist_strategy.hybrid_configs = {
"dp_degree": 1, "dp_degree": 1,
"mp_degree": self.rank, "mp_degree": self.ranks,
"pp_degree": 1, "pp_degree": 1,
"sharding_degree": 1, "sharding_degree": 1,
} }
@@ -308,10 +318,19 @@ class PaddleDisWorkerProc():
# Local rank # Local rank
self.local_rank = fleet.worker_index() self.local_rank = fleet.worker_index()
return self.rank, self.local_rank return self.ranks, self.local_rank
def determine_num_available_blocks(self): def determine_num_available_blocks(self) -> None:
""" """Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
""" """
if self.fd_config.parallel_config.do_profile: if self.fd_config.parallel_config.do_profile:
# 1. Get available memory (bytes) # 1. Get available memory (bytes)
@@ -343,7 +362,8 @@ class PaddleDisWorkerProc():
) )
# 3. Send IPCSignal # 3. Send IPCSignal
get_profile_block_num = np.zeros(shape=[self.rank], dtype=np.int32) get_profile_block_num = np.zeros(shape=[self.ranks],
dtype=np.int32)
self.get_profile_block_num_signal = IPCSignal( self.get_profile_block_num_signal = IPCSignal(
name="get_profile_block_num", name="get_profile_block_num",
array=get_profile_block_num, array=get_profile_block_num,
@@ -366,12 +386,12 @@ class PaddleDisWorkerProc():
# 4. Update shared inputs # 4. Update shared inputs
self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_global) self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_global)
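The determine_num_available_blocks docstring above describes profiling peak memory and dividing the remainder into KV blocks. A rough, hypothetical version of that arithmetic is sketched below; the formula and byte sizes are illustrative, not FastDeploy's exact accounting.

def estimate_num_gpu_blocks(total_gpu_mem_bytes: int,
                            profiled_peak_bytes: int,
                            bytes_per_kv_block: int,
                            gpu_memory_utilization: float = 0.9) -> int:
    """How many KV blocks fit in the budget left over after profiling."""
    budget = int(total_gpu_mem_bytes * gpu_memory_utilization)
    free_for_cache = budget - profiled_peak_bytes
    return max(free_for_cache // bytes_per_kv_block, 0)

# Example: 80 GiB card, 30 GiB peak usage during profiling, 2 MiB per block.
print(estimate_num_gpu_blocks(80 * 2**30, 30 * 2**30, 2 * 2**20))  # 21504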
def init_device(self): def init_device(self) -> None:
""" """ """ Initialize the device and construct the model runner """
self.worker.init_device() self.worker.init_device()
def load_model(self): def load_model(self) -> None:
""" """ """ Load weights and create model """
self.worker.load_model() self.worker.load_model()
@@ -428,9 +448,6 @@ def parse_args():
parser.add_argument("--do_profile", parser.add_argument("--do_profile",
action='store_true', action='store_true',
help="do profile or not") help="do profile or not")
parser.add_argument("--dynamic_load_weight",
action='store_true',
help="dynamic load weight or not")
parser.add_argument("--pad_token_id", parser.add_argument("--pad_token_id",
type=int, type=int,
default=-1, default=-1,
@@ -467,14 +484,6 @@ def parse_args():
default="WINT8", default="WINT8",
type=str, type=str,
) )
parser.add_argument(
"--attention_backend",
default="APPEND_ATTN",
type=str,
choices=[
"APPEND_ATTN",
],
)
parser.add_argument("--max_num_batched_tokens", parser.add_argument("--max_num_batched_tokens",
type=int, type=int,
default=2048, default=2048,
@@ -527,11 +536,26 @@ def parse_args():
parser.add_argument("--disable_any_whitespace", parser.add_argument("--disable_any_whitespace",
action='store_false', action='store_false',
help="Disable any whitespace for guided decoding.") help="Disable any whitespace for guided decoding.")
parser.add_argument("--dynamic_load_weight",
action='store_true',
help="Enable dynamic weight loading strategy")
parser.add_argument(
"--load_strategy",
type=str,
choices=['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta', 'normal'],
default='meta',
help="Weight loading method when dynamic loading is enabled: "
"'ipc': real-time IPC streaming with automatic resharding, "
"'ipc_no_reshard': IPC streaming without weight processing, "
"'ipc_snapshot': load from a disk snapshot of IPC weights, "
"'meta': provide an RL training worker; weights are not loaded, "
"'normal': normal weight loading.")
args = parser.parse_args() args = parser.parse_args()
return args return args
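Because worker.py now imports parse_args from fastdeploy.worker.worker_process, the new flags can be exercised directly. A hedged usage sketch, assuming the remaining options can stay at their defaults:

import sys

from fastdeploy.worker.worker_process import parse_args

# Simulate the CLI the engine would normally construct.
sys.argv = ["worker_process", "--dynamic_load_weight", "--load_strategy", "ipc_snapshot"]
args = parse_args()
print(args.dynamic_load_weight, args.load_strategy)  # True ipc_snapshot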
def initialize_fd_config(args) -> FDConfig: def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
"""Initialize FDConfig """Initialize FDConfig
TODO(gongshaotian): Unified all configs to FDConfig TODO(gongshaotian): Unified all configs to FDConfig
""" """
@@ -554,7 +578,7 @@ def initialize_fd_config(args) -> FDConfig:
# model_config = ModelConfig() # model_config = ModelConfig()
decoding_config = DecodingConfig() decoding_config = DecodingConfig()
decoding_config = MoEConfig()
speculative_config = SpeculativeConfig() speculative_config = SpeculativeConfig()
parallel_config = ParallelConfig() parallel_config = ParallelConfig()
load_config = LoadConfig() load_config = LoadConfig()
@@ -592,7 +616,6 @@ def initialize_fd_config(args) -> FDConfig:
parallel_config.pad_token_id = args.pad_token_id parallel_config.pad_token_id = args.pad_token_id
parallel_config.eos_tokens_lens = args.eos_tokens_lens parallel_config.eos_tokens_lens = args.eos_tokens_lens
parallel_config.enable_chunked_prefill = args.enable_chunked_prefill parallel_config.enable_chunked_prefill = args.enable_chunked_prefill
parallel_config.attention_backend = args.attention_backend
parallel_config.max_num_batched_tokens = args.max_num_batched_tokens parallel_config.max_num_batched_tokens = args.max_num_batched_tokens
parallel_config.enable_prefix_caching = args.enable_prefix_caching parallel_config.enable_prefix_caching = args.enable_prefix_caching
@@ -600,6 +623,7 @@ def initialize_fd_config(args) -> FDConfig:
parallel_config.tensor_parallel_degree = args.tensor_parallel_size parallel_config.tensor_parallel_degree = args.tensor_parallel_size
parallel_config.expert_parallel_degree = args.expert_parallel_size parallel_config.expert_parallel_degree = args.expert_parallel_size
parallel_config.splitwise_role = args.splitwise_role parallel_config.splitwise_role = args.splitwise_role
load_config.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
parallel_config.guided_decoding_backend = args.guided_decoding_backend parallel_config.guided_decoding_backend = args.guided_decoding_backend
parallel_config.disable_any_whitespace = args.disable_any_whitespace parallel_config.disable_any_whitespace = args.disable_any_whitespace
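load_config.use_fastsafetensor above is driven by the FD_USE_FASTSAFETENSOR environment variable (read through fastdeploy.envs). A stdlib-only sketch of the same toggle; the helper function is illustrative and not part of FastDeploy:

import os

def fastsafetensor_enabled() -> bool:
    # Mirrors int(envs.FD_USE_FASTSAFETENSOR) == 1 using plain os.environ.
    return int(os.getenv("FD_USE_FASTSAFETENSOR", "0")) == 1

os.environ["FD_USE_FASTSAFETENSOR"] = "1"
print(fastsafetensor_enabled())  # True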
@@ -659,19 +683,20 @@ def initialize_fd_config(args) -> FDConfig:
moe_config.num_max_dispatch_tokens_per_rank = config.get( moe_config.num_max_dispatch_tokens_per_rank = config.get(
"num_max_dispatch_tokens_per_rank", 256) "num_max_dispatch_tokens_per_rank", 256)
moe_config.moe_use_aux_free = config.get("moe_use_aux_free", False)
model_config.ori_vocab_size = config.get("vocab_size", -1) model_config.ori_vocab_size = config.get("vocab_size", -1)
if "Ernie4_5_ForCausalLM" in config.get("architectures"): if "Ernie4_5_ForCausalLM" in config.get("architectures"):
model_config.ori_vocab_size = args.ori_vocab_size model_config.ori_vocab_size = args.ori_vocab_size
quantization_config = config.get("quantization_config", None) if "DeepseekV3ForCausalLM" in config.get("architectures"):
from paddleformers.transformers import AutoConfig
model_config.deepseekv3 = AutoConfig.from_pretrained(
args.model_name_or_path)
# Note(@wufeisheng): The `is_quantized` flag should be explicitly set to `true` #TODO(@yuanrisheng): kv_cache quant config can only be
# when the weights are actually quantized offline. For backward compatibility # stored in model config file, which should be unified
# with preview logic: quantization_config = config.get("quantization_config", None)
# - If `quantization_config` is provided but `is_quantized` is not explicitly set,
# the value of `is_quantized` will be determined by whether `kv_cache_quant_type`
# has been configured.
if not model_config.is_quantized: if not model_config.is_quantized:
if quantization_config is not None: if quantization_config is not None:
if "kv_cache_quant_type" not in quantization_config: if "kv_cache_quant_type" not in quantization_config:
@@ -689,9 +714,14 @@ def initialize_fd_config(args) -> FDConfig:
elif args.quantization != "None": elif args.quantization != "None":
quantization_config = {} quantization_config = {}
quant_config_name = args.quantization quant_config_name = args.quantization
if use_moe and quant_config_name == "wint4": quantization_config["quantization"] = quant_config_name
# Special-case handling for Ernie models; this will be unified in the future.
is_ernie = "Ernie4_5_ForCausalLM" in config.get("architectures") or \
"Ernie4_5_MoeForCausalLM" in config.get("architectures")
if use_moe and quant_config_name == "wint4" and is_ernie:
quantization_config["dense_quant_type"] = "wint8" quantization_config["dense_quant_type"] = "wint8"
quantization_config["moe_quant_type"] = "wint4" quantization_config["moe_quant_type"] = "wint4"
quantization_config["quantization"] = "mix_quant"
quant_config_name = "mix_quant" quant_config_name = "mix_quant"
else: else:
quant_config_name = None quant_config_name = None
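To make the mix-quant branch above concrete, here is a hedged, self-contained trace of what quantization_config ends up holding for an Ernie MoE checkpoint launched with --quantization wint4; the standalone variables stand in for values normally read from the model config.

architectures = ["Ernie4_5_MoeForCausalLM"]  # pretend contents of config.json
use_moe = True
quant_config_name = "wint4"

quantization_config = {"quantization": quant_config_name}
is_ernie = ("Ernie4_5_ForCausalLM" in architectures
            or "Ernie4_5_MoeForCausalLM" in architectures)
if use_moe and quant_config_name == "wint4" and is_ernie:
    quantization_config["dense_quant_type"] = "wint8"  # dense layers fall back to wint8
    quantization_config["moe_quant_type"] = "wint4"    # expert weights stay wint4
    quantization_config["quantization"] = "mix_quant"
    quant_config_name = "mix_quant"

print(quantization_config)
# {'quantization': 'mix_quant', 'dense_quant_type': 'wint8', 'moe_quant_type': 'wint4'}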
@@ -706,20 +736,26 @@ def initialize_fd_config(args) -> FDConfig:
if quant_config is not None: if quant_config is not None:
if model_config.is_quantized: if model_config.is_quantized:
logger.info( logger.info(
"=====The currently loaded model is an offline quantized model=====" "Model Status: Offline Quantized (pre-quantized weights loaded)"
) )
else: else:
logger.info("=====The currently loaded model is the original model\ logger.info(
The model will be quantized online=====") "Model Status: Original (will apply online quantization)")
logger.info(f"{json.dumps(quantization_config, indent=2)}")
logger.info(f"Quantization Method: {args.quantization or 'None'}")
else: else:
logger.info( logger.info(
"No quantization config found and use original weight and act dtype." "No quantization config found and use original weight and act dtype."
) )
logger.info("============================================")
model_config.architectures = config.get("architectures") model_config.architectures = config.get("architectures")
logger.info("===========load_config==============")
load_config.dynamic_load_weight = args.dynamic_load_weight
load_config.load_strategy = args.load_strategy
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
logger.info(f"- Load strategy: {load_config.load_strategy}")
fd_config = FDConfig(model_config=model_config, fd_config = FDConfig(model_config=model_config,
parallel_config=parallel_config, parallel_config=parallel_config,
speculative_config=speculative_config, speculative_config=speculative_config,
@@ -733,7 +769,7 @@ def initialize_fd_config(args) -> FDConfig:
return fd_config return fd_config
def run_worker_proc(): def run_worker_proc() -> None:
""" """
start worker process start worker process
""" """

View File

@@ -583,15 +583,14 @@ class XPUModelRunner(ModelRunnerBase):
head_dim = self.model_config.head_dim head_dim = self.model_config.head_dim
# Get the attention backend # Get the attention backend
attn_cls = get_attention_backend( attn_cls = get_attention_backend()
self.parallel_config.attention_backend)
attn_backend = attn_cls(self.fd_config, attn_backend = attn_cls(self.fd_config,
kv_num_heads=self.model_config.kv_num_heads, kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads, num_heads=num_heads,
head_dim=head_dim) head_dim=head_dim)
if attn_backend is None: if attn_backend is None:
raise NotImplementedError( raise NotImplementedError(
f"{self.parallel_config.attention_backend} attention backend is not supported by XPUModelRunner" "The chosen attention backend is not supported by XPUModelRunner"
) )
self.attn_backends.append(attn_backend) self.attn_backends.append(attn_backend)

View File

@@ -28,3 +28,4 @@ moviepy
triton==3.3 triton==3.3
use-triton-in-paddle use-triton-in-paddle
crcmod crcmod
fastsafetensors==0.1.14