diff --git a/build.sh b/build.sh index b7de6980a..0cf53b24b 100644 --- a/build.sh +++ b/build.sh @@ -166,7 +166,7 @@ function build_and_install() { echo -e "${BLUE}[install]${NONE} installing fastdeploy..." cd $DIST_DIR - ${python} -m pip install ./dist/fastdeploy*.whl --force-reinstall --no-cache-dir + find . -name "fastdeploy*.whl" | xargs ${python} -m pip install --force-reinstall --no-cache-dir if [ $? -ne 0 ]; then cd .. echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed" @@ -228,6 +228,9 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then ${BLUE}fastdeploy branch:${NONE} $EFFLLM_BRANCH ($EFFLLM_COMMIT)\n" echo -e "${GREEN}wheel saved under${NONE} ${RED}${BOLD}./dist${NONE}" + + # install wheel + ${python} -m pip install ./dist/fastdeploy*.whl echo -e "${GREEN}wheel install success${NONE}\n" trap : 0 diff --git a/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh b/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh new file mode 100644 index 000000000..3ac80b6cc --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh @@ -0,0 +1,236 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + + +#include "multi_head_latent_attention_kernel.h" + +template +struct softmax_state_t { + AlignedVector o; + T m; + T d; + + __device__ __forceinline__ void init() { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&o) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0); + } + } + d = 1.f; + if constexpr (std::is_same::value) { + m = __float2half(-5e4f); + } else if constexpr (std::is_same::value) { + m = __float2bfloat16(-3.38953e38f); + } + } + + __device__ __forceinline__ softmax_state_t() { + init(); + } + + __device__ __forceinline__ void merge(const AlignedVector& other_o, + T other_m, + T other_d) { + // using kType = typename cascade_attn_nv_type2_traits::type; + T m_prev = m, d_prev = d; + m = m_prev > other_m ? 
m_prev : other_m; + T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m); + + d = d_prev * scale1 + other_d * scale2; + +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] = o[i] * scale1 + other_o[i] * scale2; + } + } + + __device__ __forceinline__ void normalize() { + +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d; + } + } + +}; + +template +struct softmax_state_ts { + uint32_t num_tiles_ = num_tiles; + AlignedVector o[num_tiles]; + float m; + float d; + + __device__ __forceinline__ void init() { +#pragma unroll + for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&o[tile_id]) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0); + } + } + } + d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.38953e38f; + } + } + + __device__ __forceinline__ softmax_state_ts() { + init(); + } + + __device__ __forceinline__ void normalize(const uint32_t tile_id) { + +#pragma unroll + for (size_t i = 0; i < vec_size; i++) { + o[tile_id][i] /= d; + } + } + +}; + +template +__device__ __forceinline__ void produce_kv(CacheT *smem, + CacheT *kv_base_gptr, + const int * block_table_smem, + const uint32_t seq_offset_gmem, + const uint32_t seq_offset_smem, + const uint32_t kv_head_idx, + const uint32_t kv_num_heads, + const uint32_t tidx, + const uint32_t chunk_start, + const uint32_t chunk_end) { + int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]); + if (block_id < 0) { + block_id = 0; + } + const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE; + // 8/16 T/int8 each time + const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK; + const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK; + for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { + pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>( + smem + smem_offset_base + vid * CACHE_VEC_SIZE, + kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE, + seq_offset_gmem < chunk_end + ); + } +} + +template +__device__ __forceinline__ void compute_qk(const T* cu_q_smem, + const CacheT* k_smem, + const uint32_t kv_idx_base, + const uint32_t stage_idx, + const uint32_t iter_base, + const uint32_t iter_bound, + const uint32_t tidx, + const uint32_t gid, + const float scale, + float *s, + softmax_state_ts& st) { + const CacheT* smem; + AlignedVector q_vec; + AlignedVector k_vec; + float m_prev = st.m; + // smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM; + smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM; +#pragma unroll + for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) { + if (iter_base + j < iter_bound) { + if constexpr (std::is_same::value) { + s[j] = 0.f; + } else if constexpr (std::is_same::value) { + s[j] = 0.f; + } +#pragma unroll + for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { + Load(cu_q_smem + vid * vec_size, &q_vec); + Load(smem + j * HEAD_DIM + vid * vec_size, &k_vec); + for (uint32_t i = 0; i < vec_size; ++i) { + s[j] += static_cast(q_vec[i] * k_vec[i]); + } + } +#pragma unroll + for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) { + s[j] += __shfl_xor_sync(-1, s[j], offset, 32); + } + __syncthreads(); + } else { + if constexpr 
(std::is_same::value) { + s[j] = -5e4f; + } else if constexpr (std::is_same::value) { + s[j] = -3.38953e38f; + } + } + st.m = st.m > s[j] ? st.m : s[j]; + } + + // T o_scale = hexp(m_prev - st.m); + float o_scale = __expf(m_prev - st.m); + st.d *= o_scale; + +#pragma unroll + for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) { + // s[j] = hexp(s[j] - st.m); + s[j] = __expf(s[j] - st.m); + st.d += s[j]; + } +#pragma unroll + for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) { + for (uint32_t i = 0; i < vec_size; ++i) { + st.o[tile_id][i] *= o_scale; + } + } +} + +template +__device__ __forceinline__ void compute_sv(const float *s, + const CacheT *base_v_smem, + const uint32_t stage_idx, + const uint32_t iter_base, + const uint32_t iter_bound, + const uint32_t tidx, + softmax_state_ts& st) { + const CacheT* v_smem; + AlignedVector v_vec; +#pragma unroll + for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) { + v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK; + for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { + Load(v_smem + vid * vec_size, &v_vec); + uint32_t tile_id = vid / bdx; +#pragma unroll + for (int reg_id = 0; reg_id < vec_size; ++reg_id) { + st.o[tile_id][reg_id] += static_cast(s[j]) * v_vec[reg_id]; + } + } + } +} diff --git a/custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu b/custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu new file mode 100644 index 000000000..2341f7284 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu @@ -0,0 +1,560 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
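The (o, m, d) triple carried by softmax_state_t above is the standard streaming-softmax state: o accumulates exp(score - m)-weighted values, m is the running score maximum, and d the running denominator, so partial results computed over different key chunks can be merged without revisiting earlier keys. A minimal host-side sketch of the same merge rule (illustrative only; plain float in place of half/nv_bfloat16, and the names here are not from the patch):

// Host-side reference for the log-sum-exp merge used by softmax_state_t::merge:
// two partial attention results over disjoint key ranges, each carried as
// (o = sum_j exp(s_j - m) * v_j, m = max_j s_j, d = sum_j exp(s_j - m)),
// are combined into one result over the union of their keys.
#include <algorithm>
#include <cmath>
#include <vector>

struct Partial {
  std::vector<float> o;  // un-normalized weighted value accumulator
  float m;               // running maximum of the scores
  float d;               // running sum of exp(score - m)
};

Partial merge(const Partial& a, const Partial& b) {
  Partial r;
  r.m = std::max(a.m, b.m);
  const float s1 = std::exp(a.m - r.m);  // rescale a's accumulators to the new max
  const float s2 = std::exp(b.m - r.m);  // rescale b's accumulators to the new max
  r.d = a.d * s1 + b.d * s2;
  r.o.resize(a.o.size());
  for (size_t i = 0; i < a.o.size(); ++i) r.o[i] = a.o[i] * s1 + b.o[i] * s2;
  return r;  // the final attention output is r.o[i] / r.d, as in normalize()
}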
+ +#include "decode_attention_func.cuh" + +#define CHECK(call) \ +do \ +{ \ + const cudaError_t error_code = call; \ + if (error_code != cudaSuccess) \ + { \ + printf("CUDA Error:\n"); \ + printf(" File: %s\n", __FILE__); \ + printf(" Line %d:\n", __LINE__); \ + printf(" Error code:%d\n", error_code); \ + printf(" Error text:%s\n", cudaGetErrorString(error_code)); \ + exit(1); \ + } \ +}while(0) + +template +__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim] + const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads] + const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads] + const int * __restrict__ seq_lens_q, + const int * __restrict__ seq_lens_kv, + const int * __restrict__ cum_offsets, + const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + OutT * __restrict__ out, // [token_num, num_heads, head_dim] + const float in_scale, + const int num_chunks, + const int chunk_size, + const int max_seq_len, + const int num_heads, + const int head_dim) { + const int vid = threadIdx.x, ty = threadIdx.y; + const int qid = blockIdx.x, hid = blockIdx.y; + const int seq_len_q = seq_lens_q[qid]; + if (seq_len_q == 0) return; + int seq_len_kv = seq_lens_kv[qid]; + if (seq_len_kv == 0) return; + seq_len_kv += seq_len_q; + const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); + if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) { + return; + } + __shared__ T smem[bdy * HEAD_DIM]; + __shared__ T md_smem[bdy * 2]; + + const int start_token_ids = qid * max_seq_len - __ldg(&cum_offsets[qid]); + using LoadT = AlignedVector; + LoadT load_vec; + LoadT res_vec; + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&res_vec) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); + } + } + T m; + T d = 1.f; + if constexpr (std::is_same::value) { + m = __float2half(-5e4f); + } else if constexpr (std::is_same::value) { + m = __float2bfloat16(-3.38953e38f); + } + // merge per ty +#pragma unroll 2 + for (int i = ty; i < num_chunks_this_seq; i += bdy) { + uint32_t offset = (qid * num_chunks + i) * num_heads + hid; + T m_prev = m; + T d_prev = d; + const T m_now = multi_m[offset]; + const T d_now = multi_d[offset]; + m = m_prev > m_now ? 
m_prev : m_now; + offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size; + Load(&multi_out[offset], &load_vec); + const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m); + d = d * scale1 + d_now * scale2; +#pragma once + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2; + } + } + // store ty res + Store(res_vec, &smem[ty * head_dim + vid * vec_size]); + md_smem[2 * ty] = m; + md_smem[2 * ty + 1] = d; + __syncthreads(); + + // merge bdy + softmax_state_t st{}; + const uint32_t iter_num = min(num_chunks_this_seq, bdy); +#pragma once + for (int i = 0; i < iter_num; i++) { + Load(&smem[i * head_dim + vid * vec_size], &load_vec); + const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; + st.merge(load_vec, m_tmp, d_tmp); + } + st.normalize(); + + AlignedVector out_vec; + +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + out_vec[i] = static_cast(st.o[i]); + } + Store(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]); +} + +template +__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim] + CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim] + CacheT * __restrict__ cache_v, + const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int * __restrict__ seq_lens_q, + const int * __restrict__ seq_lens_kv, + const int * __restrict__ cum_offsets, + const int * __restrict__ block_table, // [bsz, block_num_per_seq] + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float in_scale, + const uint32_t chunk_size, + T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim] + T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads] + T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads] + OutT * __restrict__ out) { + const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t bid = bidx, gid = threadIdx.y; + const uint32_t tidx = threadIdx.x; + constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE; + constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE; + constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx; + + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + + const int *block_table_now = block_table + bid * max_block_num_per_seq; + + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_id = blockIdx.y; + const uint32_t q_len = seq_lens_q[bid]; + if (q_len <= 0) { + return; + } + uint32_t kv_len = seq_lens_kv[bid]; // !!!!!!!! + if (kv_len <= 0) { + return; + } + kv_len += q_len; + const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size); + const uint32_t q_start_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]); + const uint32_t q_write_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]); + if (chunk_id >= num_chunk_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0; + const uint32_t chunk_end = partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK; + T *q_smem = reinterpret_cast(smem); // [HEAD_DIM_QK * sizeof(T)] + T *cu_q_smem = q_smem + gid * HEAD_DIM_QK; +#pragma unroll + for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) { + ((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0]; + + } + __syncthreads(); + using VecT = AlignedVector; + VecT q_vec; +#pragma unroll + for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) { + Load(cu_q_smem + vid * VEC_SIZE, &q_vec); + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + q_vec[i] *= scale; + } + Store(q_vec, cu_q_smem + vid * VEC_SIZE); + } + + + CacheT *kv_smem = reinterpret_cast(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT)); + uint32_t stage_idx = 0; + constexpr int loop_times = DEAL_EACH_TIME / bdy; +#pragma unroll + for (int i = 0; i < NUM_STAGES; ++i) { +#pragma unroll + for (int j = 0; j < loop_times; ++j) { + const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid; + const uint32_t k_seq_id = chunk_start + k_seq_offset; + produce_kv( + kv_smem, + cache_k, + block_table_now, + k_seq_id, + k_seq_offset, + kv_head_idx, + kv_num_heads, + tidx, + chunk_start, + chunk_end + ); + } + commit_group(); + stage_idx = (stage_idx + 1) % NUM_STAGES; + } + + + softmax_state_ts st; + float s[DEAL_EACH_TIME]; + + const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME); + for (int iter = 0; iter < num_iters; ++iter) { + wait_group(); + __syncthreads(); + // compute qk + compute_qk( + cu_q_smem, + kv_smem, + chunk_start + iter * DEAL_EACH_TIME, + stage_idx, + iter * DEAL_EACH_TIME, + chunk_len, + tidx, + gid, + scale, + s, + st + ); + __syncthreads(); + + // compute sv + compute_sv( + s, + kv_smem, + stage_idx, + iter * DEAL_EACH_TIME, + chunk_len, + tidx, + st + ); + __syncthreads(); + +#pragma unroll + for (int j = 0; j < loop_times; ++j) { + const uint32_t k_seq_offset = j * bdy + gid; + produce_kv( + kv_smem, + cache_k, + block_table_now, + chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME, + stage_idx * DEAL_EACH_TIME + k_seq_offset, + kv_head_idx, + kv_num_heads, + tidx, + chunk_start, + chunk_end + ); + } + commit_group(); + stage_idx = (stage_idx + 1) % NUM_STAGES; + } + wait_group<0>(); + __syncthreads(); + + // normize if not partition_kv + for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) { + const uint32_t tile_id = vid / bdx; + if (!partition_kv || num_chunk_this_seq == 1) { + st.normalize(tile_id); + } + if (partition_kv && num_chunk_this_seq > 1) { + const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx; + Store(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE); + tmp_m[head_idx] = st.m; + tmp_d[head_idx] = st.d; + } else { + Store(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE); + } + } +} + + +template +void MultiQueryDecoderAttention( + const AppendAttnMetaData& meta_data, + cudaStream_t &stream, + const paddle::Tensor &q, + const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim] + const paddle::Tensor &cache_v, // [num_kv_heads, head_dim] + const paddle::optional& attn_mask, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor 
&padding_offsets, + const paddle::Tensor &cum_offsets, + const paddle::Tensor &block_table, + const int max_seq_len, + const int max_dec_len, + const float rope_scale, + const float rope_theta, + const float softmax_scale, + const float in_scale, + paddle::Tensor *out) { + using NV_TYPE = typename cascade_attn_type_traits::type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_nums; + auto bsz = meta_data.batch_size; + auto max_block_num_per_seq = meta_data.max_blocks_per_seq; + constexpr int num_stages = NUM_STAGE; + + constexpr int vec_size = 16 / sizeof(T); // 8 16 32 + constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32 + constexpr int blockxc = HEAD_DIM_QK / cache_vec_size; + constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size; + constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32; + + constexpr int blocky = GROUP_SIZE; + const int gridx = bsz; + + constexpr int num_threads = blockx * blocky; + + auto splitkv_kernel = multi_query_decode_attention_kernel; + uint32_t cache_smem_bytes = 0; + + const T *shift_bias_ptr = shift_bias ? shift_bias.get().data() : nullptr; + const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data() : nullptr; + cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T); + + const uint32_t chunk_size = get_max_partition_size(bsz); + const int num_chunks = div_up(max_dec_len, chunk_size); + size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T); + + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + const int dev_id = 0; + int sm_count; + int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, splitkv_kernel, num_threads, smem_size); + assert(act_blocks_per_sm > 1); + + const int num_blocks_per_wave = sm_count * act_blocks_per_sm; + const int num_blocks_need = gridx * num_chunks * kv_num_heads; + const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); + const float ratio = static_cast(num_blocks_need) / static_cast(num_blocks_per_wave); + + dim3 grids(gridx, num_chunks, kv_num_heads); + dim3 blocks(blockx, blocky); + if (num_chunks <= 1) { + auto no_splitkv_kernel = multi_query_decode_attention_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + no_splitkv_kernel<<>>( + reinterpret_cast(const_cast(q.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + reinterpret_cast(const_cast(shift_bias_ptr)), + reinterpret_cast(const_cast(smooth_weight_ptr)), + seq_lens_q.data(), + seq_lens_kv.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + softmax_scale, + in_scale, + chunk_size, + nullptr, + nullptr, + nullptr, + reinterpret_cast(const_cast(out->data())) + ); + + // CHECK(cudaGetLastError()); + // CHECK(cudaDeviceSynchronize()); + } else { + auto *allocator = paddle::GetAllocator(q.place()); + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + tmp_workspace = allocator->Allocate( + phi::SizeOf(q.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM_V)); + tmp_m = allocator->Allocate( + phi::SizeOf(q.dtype()) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = 
allocator->Allocate( + phi::SizeOf(q.dtype()) * + static_cast(bsz * num_chunks * num_heads)); + + splitkv_kernel<<>>( + reinterpret_cast(const_cast(q.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + reinterpret_cast(const_cast(shift_bias_ptr)), + reinterpret_cast(const_cast(smooth_weight_ptr)), + seq_lens_q.data(), + seq_lens_kv.data(), + cum_offsets.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + softmax_scale, + in_scale, + chunk_size, + reinterpret_cast(tmp_workspace->ptr()), + reinterpret_cast(tmp_m->ptr()), + reinterpret_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(out->data())) + ); + // CHECK(cudaGetLastError()); + // CHECK(cudaDeviceSynchronize()); + + constexpr int mblockx = HEAD_DIM_V / vec_size; + constexpr int bdy = 256 / mblockx; + dim3 grids_merge(bsz, num_heads); + dim3 blocks_merge(mblockx, bdy); + merge_varlen_multi_chunks_v2_kernel<<>>( + reinterpret_cast(tmp_workspace->ptr()), + reinterpret_cast(tmp_m->ptr()), + reinterpret_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + cum_offsets.data(), + reinterpret_cast(const_cast(shift_bias_ptr)), + reinterpret_cast(const_cast(smooth_weight_ptr)), + reinterpret_cast(const_cast(out->data())), + in_scale, + num_chunks, + chunk_size, + max_seq_len, + num_heads, + HEAD_DIM_V + ); + } + // CHECK(cudaGetLastError()); + // CHECK(cudaDeviceSynchronize()); +} + +template +void DecodeMLAAttentionKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional& attn_mask, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &padding_offsets, + const paddle::Tensor &cum_offsets, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out) { + const auto token_num = meta_data.token_nums; + const auto block_size = meta_data.block_size; + const auto bsz = meta_data.batch_size; + const auto num_heads = meta_data.q_num_heads; + const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; + const auto head_dim_qk = meta_data.head_dims; + const auto head_dim_v = meta_data.head_dims_v; + const float rope_scale = 0.0; + const float rope_theta = 0.0; + const uint32_t deal_each_time = get_cascade_attention_deal_each_time(); + const uint32_t num_stage = get_cascade_attention_num_stages(); + const uint32_t num_threads = get_cascade_attention_num_threads(); + + DISPATCH_CAUSAL(causal, CAUSAL, + {DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, + {DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK, + {DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V, + {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, + {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, + {MultiQueryDecoderAttention( + meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cum_offsets, + block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})}); +} + +template void DecodeMLAAttentionKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const 
paddle::optional& attn_mask, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &padding_offsets, + const paddle::Tensor &cum_offsets, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out); + +template void DecodeMLAAttentionKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional& attn_mask, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &padding_offsets, + const paddle::Tensor &cum_offsets, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu new file mode 100644 index 000000000..a1c3cee22 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu @@ -0,0 +1,291 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
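MultiQueryDecoderAttention above partitions the KV sequence into chunks of get_max_partition_size(bsz) tokens; each (batch, chunk, kv_head) block writes a partial (o, m, d) into tmp_workspace/tmp_m/tmp_d, and merge_varlen_multi_chunks_v2_kernel then folds the chunks per token and head with the same rescaling rule. A rough host-side sketch of the launch geometry and buffer sizing follows; the concrete numbers (bf16, HEAD_DIM_QK = 576, HEAD_DIM_V = 512, GROUP_SIZE = 8, DEAL_EACH_TIME = 32, NUM_STAGES = 2, batch/head counts) are assumptions for illustration, not values taken from the patch:

// Launch-geometry sketch for the split-KV decode path (illustrative only).
#include <cstdint>
#include <cstdio>

constexpr uint32_t div_up_u(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  const uint32_t bsz = 4, kv_num_heads = 1, num_heads = 8;        // assumed
  const uint32_t max_dec_len = 65536, chunk_size = 32768;         // FLAGS default
  const uint32_t num_chunks = div_up_u(max_dec_len, chunk_size);  // 2 -> split-KV path

  // Per-chunk partial results, merged later by merge_varlen_multi_chunks_v2_kernel.
  const size_t head_dim_v = 512, head_dim_qk = 576, elem = 2 /* bf16 */;
  const size_t tmp_workspace = (size_t)bsz * num_chunks * num_heads * head_dim_v * elem;
  const size_t tmp_md = (size_t)bsz * num_chunks * num_heads * elem;  // tmp_m and tmp_d each

  // Dynamic shared memory: staged K/V tiles plus the per-group query tile.
  const size_t deal_each_time = 32, num_stages = 2, group_size = 8;
  const size_t smem = num_stages * deal_each_time * head_dim_qk * elem
                    + group_size * head_dim_qk * elem;  // 82,944 B, hence the >48 KB opt-in

  printf("grid=(%u,%u,%u) workspace=%zu + 2*%zu smem=%zu\n",
         bsz, num_chunks, kv_num_heads, tmp_workspace, tmp_md, smem);
  return 0;
}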
+#pragma once + +#include "mla_cache_kernel.cuh" + +template +std::vector PrefillMLAWriteCache( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* kv_cache) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + auto num_tokens = meta_data.token_nums; + auto block_size = meta_data.block_size; + auto nope_size = meta_data.head_dims_v; + auto all_size = meta_data.head_dims; + int pe_size = all_size - nope_size; + auto kv_num_heads = meta_data.kv_num_heads; + const uint32_t elem_nums = num_tokens * kv_num_heads * all_size; + + constexpr int PackSize = 16 / sizeof(DataType_); + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + + prefill_absorb_cache_kernel + <<>>( + reinterpret_cast(const_cast(kv_nope.data())), + reinterpret_cast(const_cast(kv_pe.data())), + reinterpret_cast(kv_cache->data()), + block_tables.data(), + padding_offsets.data(), + cum_offsets.data(), + seq_lens.data(), + seq_lens_decoder.data(), + max_seq_len, + max_blocks_per_seq, + kv_num_heads, + nope_size, + pe_size, + block_size, + elem_nums); + return {}; +} + +std::vector PrefillMLAWriteCacheKernel( + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& kv_cache, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const std::string& cache_quant_type_str, + const int max_seq_len) { + cudaStream_t stream = kv_pe.stream(); + AppendAttnMetaData meta_data; + const auto& kv_nope_dims = kv_nope.dims(); + const auto& kv_pe_dims = kv_pe.dims(); + const auto& kv_cache_dims = kv_cache.dims(); + meta_data.kv_num_heads = kv_cache_dims[1]; + const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; + meta_data.token_nums = kv_nope_dims[0]; + meta_data.head_dims = kv_cache_dims[3]; + meta_data.head_dims_v = nope_size; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = kv_cache_dims[2]; + meta_data.batch_size = cum_offsets.dims()[0]; + switch (kv_pe.dtype()) { + case paddle::DataType::BFLOAT16: { + return PrefillMLAWriteCache(meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_decoder, + padding_offsets, + cum_offsets, + block_tables, + max_seq_len, + stream, + const_cast(&kv_cache)); + } + case paddle::DataType::FLOAT16: { + return PrefillMLAWriteCache(meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_decoder, + padding_offsets, + cum_offsets, + block_tables, + max_seq_len, + stream, + const_cast(&kv_cache)); + } + } + return {}; +} + +template +std::vector DecodeMLAWriteCache( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const int max_seq_len, + const bool speculate_decoder, + cudaStream_t& stream, + paddle::Tensor* kv_cache) { + 
typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + auto bsz = meta_data.batch_size; + auto token_num = meta_data.token_nums; + auto block_size = meta_data.block_size; + auto nope_size = meta_data.head_dims_v; + auto all_size = meta_data.head_dims; + int pe_size = all_size - nope_size; + auto kv_num_heads = meta_data.kv_num_heads; + constexpr int PackSize = 16 / sizeof(DataType_); + const int blocksize = 128; + int grid_size = 1; + + + if (speculate_decoder) { + const uint32_t elem_nums = token_num * kv_num_heads * all_size; + const int pack_num = elem_nums / PackSize; + GetNumBlocks<128>(pack_num, &grid_size); + speculate_decode_absorb_cache_kernel + <<>>( + reinterpret_cast(const_cast(kv_nope.data())), + reinterpret_cast(const_cast(kv_pe.data())), + reinterpret_cast(kv_cache->data()), + block_tables.data(), + padding_offsets.data(), + cum_offsets.data(), + seq_lens.data(), + seq_lens_encoder.data(), + max_seq_len, + max_blocks_per_seq, + kv_num_heads, + nope_size, + pe_size, + block_size, + elem_nums); + } else { + const uint32_t elem_nums = bsz * kv_num_heads * all_size; + const int pack_num = elem_nums / PackSize; + GetNumBlocks<128>(pack_num, &grid_size); + decode_absorb_cache_kernel + <<>>( + reinterpret_cast(const_cast(kv_nope.data())), + reinterpret_cast(const_cast(kv_pe.data())), + reinterpret_cast(kv_cache->data()), + block_tables.data(), + cum_offsets.data(), + seq_lens.data(), + seq_lens_encoder.data(), + max_seq_len, + max_blocks_per_seq, + kv_num_heads, + nope_size, + pe_size, + block_size, + elem_nums); + } + return {}; +} + +std::vector DecodeMLAWriteCacheKernel( + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& kv_cache, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const std::string& cache_quant_type_str, + const int max_seq_len, + const bool speculate_decoder) { + cudaStream_t stream = kv_pe.stream(); + AppendAttnMetaData meta_data; + const auto& kv_nope_dims = kv_nope.dims(); + const auto& kv_pe_dims = kv_pe.dims(); + const auto& kv_cache_dims = kv_cache.dims(); + meta_data.kv_num_heads = kv_cache_dims[1]; + const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; + meta_data.token_nums = kv_nope_dims[0]; + meta_data.head_dims = kv_cache_dims[3]; + meta_data.head_dims_v = nope_size; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = kv_cache_dims[2]; + meta_data.batch_size = cum_offsets.dims()[0]; + switch (kv_pe.dtype()) { + case paddle::DataType::BFLOAT16: { + return DecodeMLAWriteCache(meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + max_seq_len, + speculate_decoder, + stream, + const_cast(&kv_cache)); + } + case paddle::DataType::FLOAT16: { + return DecodeMLAWriteCache(meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + max_seq_len, + speculate_decoder, + stream, + const_cast(&kv_cache)); + } + } + return {}; +} + + +PD_BUILD_OP(prefill_mla_write_cache) + .Inputs({"kv_nope", + "kv_pe", + "kv_cache", + "seq_lens", + "seq_lens_decoder", + "padding_offsets", + "cum_offsets", + "block_tables"}) + .Outputs({"kv_cache_out"}) + .SetInplaceMap({{"kv_cache", "kv_cache_out"}}) + 
.Attrs({"cache_quant_type_str: std::string", + "max_seq_len: int"}) + .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel)); + +PD_BUILD_OP(decode_mla_write_cache) + .Inputs({"kv_nope", + "kv_pe", + "kv_cache", + "seq_lens", + "seq_lens_encoder", + "padding_offsets", + "cum_offsets", + "block_tables"}) + .Outputs({"kv_cache_out"}) + .SetInplaceMap({{"kv_cache", "kv_cache_out"}}) + .Attrs({"cache_quant_type_str: std::string", + "max_seq_len: int", + "speculate_decoder: bool"}) + .SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel)); diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh new file mode 100644 index 000000000..b0ab79d95 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh @@ -0,0 +1,242 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "helper.h" +#include "mem_util.cuh" +#include "utils.cuh" + +template +__global__ void decode_absorb_cache_kernel( + const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512 + const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 + T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + // nope_size] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const int max_seq_len, + const int max_blocks_per_seq, + const int kv_num_heads, + const int nope_size, + const int pe_size, + const int block_size, + const uint32_t elem_cnt) { + using LoadT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + LoadT src_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t nope_hidden_size = kv_num_heads * nope_size; + const uint32_t pe_hidden_size = kv_num_heads * pe_size; + const uint32_t all_size = nope_size + pe_size; + const int64_t hidden_size = nope_hidden_size + pe_hidden_size; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int ori_bi = linear_index / hidden_size; + const int bias = linear_index % hidden_size; + const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi]; + if (seq_lens_encoder[ori_bi] > 0) return; + const int write_seq_id = seq_lens[ori_bi]; + + if (write_seq_id == 0) continue; + + const int* block_table_now = nullptr; + + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + + if (bias < nope_hidden_size) { // pe + const uint32_t inner_bias = bias; + const uint32_t hi = inner_bias / nope_size; + const uint32_t h_bias = inner_bias % nope_size; + const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + + block_offset * 
all_size + h_bias; + const uint32_t ori_idx = + start_token_idx * nope_hidden_size + inner_bias; + Load(&kv_nope[ori_idx], &src_vec); + Store(src_vec, &kv_cache[tgt_idx]); + } else { + const uint32_t inner_bias = bias - nope_hidden_size; + const uint32_t hi = inner_bias / pe_size; + const uint32_t h_bias = inner_bias % pe_size; + const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + + block_offset * all_size + nope_size + h_bias; + const uint32_t ori_idx = + start_token_idx * pe_hidden_size + inner_bias; + Load(&kv_pe[ori_idx], &src_vec); + Store(src_vec, &kv_cache[tgt_idx]); + } + } +} + +template +__global__ void speculate_decode_absorb_cache_kernel( + const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512 + const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 + T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + // nope_size] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const int max_seq_len, + const int max_blocks_per_seq, + const int kv_num_heads, + const int nope_size, + const int pe_size, + const int block_size, + const uint32_t elem_cnt) { + using LoadT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + LoadT src_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t nope_hidden_size = kv_num_heads * nope_size; + const uint32_t pe_hidden_size = kv_num_heads * pe_size; + const uint32_t all_size = nope_size + pe_size; + const int64_t hidden_size = nope_hidden_size + pe_hidden_size; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_id = linear_index / hidden_size; + const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len; + if (seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % hidden_size; + const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi]; + const int write_seq_id = + seq_lens[ori_bi] + token_id - start_token_idx; + if (write_seq_id == 0) continue; + + const int* block_table_now = nullptr; + + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + if (block_idx < 0) { + printf( + "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " + "%d %d %d %d\n", + block_idx, + write_seq_id, + ori_bi, + seq_lens[ori_bi], + token_id, + cum_offsets[ori_bi]); + } + if (bias < nope_hidden_size) { // pe + const uint32_t inner_bias = bias; + const uint32_t hi = inner_bias / nope_size; + const uint32_t h_bias = inner_bias % nope_size; + const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + + block_offset * all_size + h_bias; + const uint32_t ori_idx = + token_id * nope_hidden_size + inner_bias; + Load(&kv_nope[ori_idx], &src_vec); + Store(src_vec, &kv_cache[tgt_idx]); + } else { + const uint32_t inner_bias = bias - nope_hidden_size; + const uint32_t hi = inner_bias / pe_size; + const uint32_t h_bias = inner_bias % pe_size; + const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + + block_offset * all_size + nope_size + h_bias; + const 
uint32_t ori_idx = + token_id * pe_hidden_size + inner_bias; + Load(&kv_pe[ori_idx], &src_vec); + Store(src_vec, &kv_cache[tgt_idx]); + } + } +} + +template +__global__ void prefill_absorb_cache_kernel( + const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512 + const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 + T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + // nope_size] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ padding_offsets, + const int* __restrict__ cum_offsets, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_decoder, // [bsz] + const int max_seq_len, + const int max_blocks_per_seq, + const int kv_num_heads, + const int nope_size, + const int pe_size, + const int block_size, + const uint32_t elem_cnt) { + using LoadT = AlignedVector; + LoadT src_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t nope_hidden_size = kv_num_heads * nope_size; + const uint32_t pe_hidden_size = kv_num_heads * pe_size; + const uint32_t all_size = nope_size + pe_size; + const int64_t hidden_size = nope_hidden_size + pe_hidden_size; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const uint32_t token_idx = linear_index / hidden_size; + const uint32_t bias = linear_index % hidden_size; + const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx]; + const uint32_t ori_bi = ori_token_idx / max_seq_len; + if (seq_lens[ori_bi] == 0) continue; + const uint32_t ori_seq_id = + ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi]; + + const int* block_table_now = nullptr; + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const uint32_t block_idx = block_table_now[ori_seq_id / block_size]; + const uint32_t block_offset = ori_seq_id % block_size; + + if (bias < nope_hidden_size) { // pe + const uint32_t inner_bias = bias; + const uint32_t hi = inner_bias / nope_size; + const uint32_t h_bias = inner_bias % nope_size; + const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + + block_offset * all_size + h_bias; + const uint32_t ori_idx = + token_idx * nope_hidden_size + inner_bias; + Load(&kv_nope[ori_idx], &src_vec); + Store(src_vec, &kv_cache[tgt_idx]); + } else { + const uint32_t inner_bias = bias - nope_hidden_size; + const uint32_t hi = inner_bias / pe_size; + const uint32_t h_bias = inner_bias % pe_size; + const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + + block_offset * all_size + nope_size + h_bias; + const uint32_t ori_idx = + token_idx * pe_hidden_size + inner_bias; + Load(&kv_pe[ori_idx], &src_vec); + Store(src_vec, &kv_cache[tgt_idx]); + } + } +} diff --git a/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h b/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h new file mode 100644 index 000000000..013621bd2 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "helper.h" +#include "utils.cuh" + +template +void DecodeMLAAttentionKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional& attn_mask, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &padding_offsets, + const paddle::Tensor &cum_offsets, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh index 5be300177..05f500126 100644 --- a/custom_ops/gpu_ops/append_attn/utils.cuh +++ b/custom_ops/gpu_ops/append_attn/utils.cuh @@ -25,6 +25,7 @@ struct AppendAttnMetaData { int kv_num_heads; int token_nums; int head_dims; + int head_dims_v; int max_blocks_per_seq; }; @@ -309,10 +310,56 @@ __forceinline__ __host__ __device__ void vec_cast( } \ } -#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \ - if (num_stage == 2) { \ - constexpr size_t NUM_STAGE = 2; \ - __VA_ARGS__ \ +#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 192: { \ + constexpr size_t HEAD_DIM = 192; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + PD_THROW("not support the head_dim: ", head_dim); \ + } \ + } + +#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 192: { \ + constexpr size_t HEAD_DIM = 192; \ + __VA_ARGS__ \ + break; \ + } \ + case 512: { \ + constexpr size_t HEAD_DIM = 512; \ + __VA_ARGS__ \ + break; \ + } \ + case 576: { \ + constexpr size_t HEAD_DIM = 576; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + PD_THROW("not support the head_dim: ", head_dim); \ + } \ + } + +#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \ + if (num_stage == 2) { \ + constexpr size_t NUM_STAGE = 2; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the num_stage: ", num_stage); \ } #define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \ @@ -328,10 +375,13 @@ __forceinline__ __host__ __device__ void vec_cast( constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \ constexpr size_t cache_bytes = 4; \ __VA_ARGS__ \ + } else { \ + PD_THROW("not support the cache_type: ", cache_type); \ } + #define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \ - if (deal_each_time == 32) { \ + if (deal_each_time == 32) { \ constexpr size_t DEAL_EACH_TIME = 32; \ __VA_ARGS__ \ } else if (deal_each_time == 64) { \ @@ -387,6 +437,20 @@ __forceinline__ __host__ __device__ void vec_cast( PD_THROW("not support the group_size", group_size); \ } +#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) 
\ + if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else if (group_size == 128) { \ + constexpr size_t GROUP_SIZE = 128; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size: ", group_size); \ + } + #define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ if (block_shape_q <= 16) { \ constexpr size_t BLOCK_SHAPE_Q = 16; \ diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 60920b629..5eb56c14f 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -316,6 +316,96 @@ void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input, paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids, int64_t num_experts); +void GetPositionIdsAndMaskEncoderBatch( + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& position_ids, + const paddle::Tensor& mask_encoder_batch); + +std::vector DecodeMLAWriteCacheKernel( + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& kv_cache, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const std::string& cache_quant_type_str, + const int max_seq_len, + const bool speculate_decoder); + + std::vector PrefillMLAWriteCacheKernel( + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& kv_cache, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const std::string& cache_quant_type_str, + const int max_seq_len); + + +void FusedRotaryPositionEncoding( + paddle::Tensor& query, // [num_tokens, num_heads, head_size] or + // [num_tokens, num_heads * head_size] + paddle::Tensor& key, + // [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads * + // head_size] + const paddle::Tensor& position_ids, // [num_tokens] + const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim] + int head_size, + bool is_neox); + +std::vector MultiHeadLatentAttention( + const paddle::Tensor& query, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks, + const paddle::Tensor& decoder_num_blocks_cpu, + const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& max_dec_len_this_time, + const paddle::Tensor& max_len_kv, + const paddle::optional& attn_mask, + const paddle::optional& query_bias, + const paddle::optional& query_out_scales, + const paddle::optional& 
cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const int nope_size, + const int max_input_length, + const float softmax_scale, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder); std::vector tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M); @@ -370,6 +460,14 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input, paddle::Tensor &scales, float scale_ub); +std::vector NoauxTc( + paddle::Tensor& scores, + paddle::Tensor& scores_with_bias, + int n_group, + int topk_group, + int topk, + float routed_scaling_factor); + PYBIND11_MODULE(fastdeploy_ops, m) { m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"), @@ -627,6 +725,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("use_atomic_add"), py::arg("use_fp32_reduce"), py::arg("is_zp_float")); + m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch, + "get_position_ids_and_mask_encoder_batch function"); /** @@ -653,4 +753,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant, "dynamic_per_token_scaled_fp8_quant function", py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub")); + m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function"); + + m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function"); + + m.def("fused_rotary_position_encoding", &FusedRotaryPositionEncoding, "fused_rotary_position_encoding function"); + + m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function"); + + m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); } diff --git a/custom_ops/gpu_ops/env.h b/custom_ops/gpu_ops/env.h new file mode 100644 index 000000000..c7db21ba8 --- /dev/null +++ b/custom_ops/gpu_ops/env.h @@ -0,0 +1,64 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +inline uint32_t get_decoder_block_shape_q() { + static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q"); + static const uint32_t decoder_block_shape_q = + decoder_block_shape_q_env == nullptr ? 
16 : std::stoi(std::string(decoder_block_shape_q_env)); + return decoder_block_shape_q; +} + +inline uint32_t get_encoder_block_shape_q() { + static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q"); + static const uint32_t encoder_block_shape_q = + encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env)); + return encoder_block_shape_q; +} + +inline uint32_t get_max_partition_size(int bsz) { + static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size"); + static const uint32_t max_partition_size = + max_partition_size_env == nullptr ? 32768 : std::stoul(std::string(max_partition_size_env)); + return max_partition_size; +} + +inline uint32_t get_cascade_attention_deal_each_time() { + static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time"); + static const uint32_t cascade_attention_deal_each_time = + cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env)); + return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32); +} + +inline uint32_t get_cascade_attention_num_stages() { + static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages"); + static const uint32_t cascade_attention_num_stages = + cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env)); + return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2; +} + +inline uint32_t get_cascade_attention_num_threads() { + static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads"); + static const uint32_t cascade_attention_num_threads = + cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env)); + return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128; +} + +inline bool get_mla_use_tensorcore() { + static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore"); + static const uint32_t mla_use_tensorcore = + mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env)); + return mla_use_tensorcore != 0 ? true : false; +} diff --git a/custom_ops/gpu_ops/fused_rotary_position_encoding.cu b/custom_ops/gpu_ops/fused_rotary_position_encoding.cu new file mode 100644 index 000000000..41670fec8 --- /dev/null +++ b/custom_ops/gpu_ops/fused_rotary_position_encoding.cu @@ -0,0 +1,146 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
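For reference, the per-pair update applied by apply_token_rotary_embedding_kernel in the file below reduces to the scalar rotation sketched here (host-side float, one head row; is_neox selects the pairing exactly as in the kernel, everything else is illustrative):

// Scalar reference for one head row. cos_sin_cache stores, per position,
// rot_dim/2 cosines followed by rot_dim/2 sines.
void rotate_one_head(float* head, const float* cos_row, const float* sin_row,
                     int rot_dim, bool is_neox) {
  const int embed_dim = rot_dim / 2;
  for (int k = 0; k < embed_dim; ++k) {
    // NeoX-style pairs element k with k + rot_dim/2; the non-NeoX (interleaved)
    // branch pairs adjacent elements (2k, 2k + 1). Both read cos/sin index k.
    const int xi = is_neox ? k : 2 * k;
    const int yi = is_neox ? k + embed_dim : 2 * k + 1;
    const float c = cos_row[k], s = sin_row[k];
    const float x = head[xi], y = head[yi];
    head[xi] = x * c - y * s;
    head[yi] = y * c + x * s;
  }
}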
+ +#include "helper.h" +#include "paddle/extension.h" + +template +inline __device__ void apply_token_rotary_embedding_kernel( + T* __restrict__ arr, + const T* __restrict__ cos_ptr, + const T* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) { + int x_index, y_index; + T cos, sin; + if (IS_NEOX) { + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = cos_ptr[x_index]; + sin = sin_ptr[x_index]; + } else { + x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = cos_ptr[x_index / 2]; + sin = sin_ptr[x_index / 2]; + } + + const T x = arr[x_index]; + const T y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + + +template +__global__ void apply_rotary_embedding_kernel( + T* __restrict__ query, // [num_tokens, num_heads, head_size] + T* __restrict__ key, // [num_tokens, num_kv_heads, head_size] + const int* __restrict__ position_ids, // [num_tokens] + const T* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int pos = position_ids[token_idx]; + const T* cache_ptr = cos_sin_cache + pos * rot_dim; + + const int embed_dim = rot_dim / 2; + const T* cos_ptr = cache_ptr; + const T* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding_kernel( + query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding_kernel( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } +} + + +void FusedRotaryPositionEncoding( + paddle::Tensor& query, // [num_tokens, num_heads, head_size] or + // [num_tokens, num_heads * head_size] + paddle::Tensor& key, + // [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads * + // head_size] + const paddle::Tensor& position_ids, // [num_tokens] + const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim] + int head_size, + bool is_neox) { + int64_t num_tokens = query.dims()[0]; + int num_heads = query.numel() / num_tokens / head_size; + int num_kv_heads = key.numel() / num_tokens / head_size; + int rot_dim = cos_sin_cache.dims()[1]; + int64_t query_stride = num_heads * head_size; + int64_t key_stride = num_kv_heads * head_size; + + if (num_tokens > 65535) { + PD_THROW( + "apply_rotary_embedding_kernel launch failed when num_tokens > 65535."); + } + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + query.dtype(), "apply_rotary_embedding_kernel", [&] { + if (is_neox) { + apply_rotary_embedding_kernel + <<>>(query.data(), + key.data(), + position_ids.data(), + cos_sin_cache.data(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } else { + apply_rotary_embedding_kernel + <<>>(query.data(), + key.data(), + position_ids.data(), + cos_sin_cache.data(), + rot_dim, + query_stride, + 
key_stride, + num_heads, + num_kv_heads, + head_size); + } + }); +} + +PD_BUILD_OP(fused_rotary_position_encoding) + .Inputs({"query", "key", "position_ids", "cos_sin_cache"}) + .Outputs({"query_out", "key_out"}) + .Attrs({"head_size: int", "is_neox: bool"}) + .SetInplaceMap({{"query", "query_out"}, {"key", "key_out"}}) + .SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding)); diff --git a/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu b/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu new file mode 100644 index 000000000..f58705d9f --- /dev/null +++ b/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu @@ -0,0 +1,86 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +__global__ void GetPositionIdsAndMaskEncoderBatchKernel( + const int* seq_lens_encoder, // [bsz] encoder length of each batch entry + const int* seq_lens_decoder, // [bsz] decoder length of each batch entry + const int* seq_lens_this_time, + int* position_ids, // flattened output position_ids + int* mask_encoder_batch, + const int bsz) { // batch size + // current thread index (one thread per batch entry) + int tid = threadIdx.x; + if (tid >= bsz) return; + + // compute the output offset of this batch entry on the fly + int offset = 0; + for (int i = 0; i < tid; i++) { + offset += seq_lens_encoder[i]; + if (seq_lens_decoder[i] > 0) { + offset += seq_lens_this_time[i]; + } + } + + // encoder and decoder lengths of this batch entry + int encoder_len = seq_lens_encoder[tid]; + int decoder_len = seq_lens_decoder[tid]; + int seq_len_this_time = seq_lens_this_time[tid]; + + // write position_ids for the encoder part + for (int i = 0; i < encoder_len; i++) { + position_ids[offset + i] = i; + mask_encoder_batch[offset + i] = 1; + } + offset += encoder_len; + + // write position_ids for the decoder part + if (decoder_len > 0) { + for (int i = 0; i < seq_len_this_time; i++) { + position_ids[offset + i] = decoder_len + i; // continue from the decoder length itself + mask_encoder_batch[offset + i] = 0; + } + } +} + + +void GetPositionIdsAndMaskEncoderBatch( + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& position_ids, + const paddle::Tensor& mask_encoder_batch) { + const int bsz = seq_lens_this_time.shape()[0]; + + GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>( + seq_lens_encoder.data(), + seq_lens_decoder.data(), + seq_lens_this_time.data(), + const_cast(position_ids.data()), + const_cast(mask_encoder_batch.data()), + bsz); +} + +PD_BUILD_OP(get_position_ids_and_mask_encoder_batch) + .Inputs({"seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "position_ids", + "mask_encoder_batch"}) + .Outputs({"position_ids_out", "mask_encoder_batch_out"}) + .SetInplaceMap({{"position_ids", "position_ids_out"}, + {"mask_encoder_batch", "mask_encoder_batch_out"}}) + .SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch)); diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index ab56ac144..f829bf1ff 100644 ---
a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -39,10 +39,12 @@ namespace cub = hipcub; #include #include +#include "env.h" #include "paddle/extension.h" #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/cuda_stream.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/backends/gpu/gpu_info.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) @@ -513,3 +515,10 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { cudaDevAttrMaxSharedMemoryPerBlockOptin, device); return max_shared_mem_per_block_opt_in; } + +inline int GetSMVersion() { + static int sm_version = phi::backends::gpu::GetGPUComputeCapability( + phi::backends::gpu::GetCurrentDeviceId()); + return sm_version; + +} diff --git a/custom_ops/gpu_ops/mla_attn/attention_updater.cuh b/custom_ops/gpu_ops/mla_attn/attention_updater.cuh new file mode 100644 index 000000000..49f8089d3 --- /dev/null +++ b/custom_ops/gpu_ops/mla_attn/attention_updater.cuh @@ -0,0 +1,255 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ + +#include +#include + +#include "utils.cuh" + +namespace mla_attn { + +using namespace cute; + +template +struct MaxOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); } +}; + +template +struct SumOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template <> +struct Allreduce<2> { + template + static __device__ __forceinline__ T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +template +__device__ __forceinline__ void thread_reduce_(Tensor const& tensor, + Tensor& summary, Operator& op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor& dst, + Tensor& src, Operator& op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, + Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, + Tensor& max) { + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, + Tensor& sum) { + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); + if constexpr (warp_reduce) { + quad_allreduce_(sum, sum, sum_op); + } +} + +template +__forceinline__ __device__ void apply_exp2(Tensor& tensor, + Tensor const& max) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + auto row_max = max(mi); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + tensor(mi, ni) = __expf(tensor(mi, ni) - row_max); + } + } +} + +template +__forceinline__ __device__ void scale_apply_exp2(Tensor& tensor, + Tensor const& max, + const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + auto row_max = max(mi); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // row_max * scale is a constant for each row, so we can use fma here + tensor(mi, ni) = __expf(tensor(mi, ni) * scale - row_max * scale); + } + } +} + +template +struct OnlineSoftmax { + constexpr static float fill_value = -5e4; + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum, scores_scale; + float sm_scale_log2; + + CUTLASS_DEVICE OnlineSoftmax(float sm_scale_log2) : sm_scale_log2(sm_scale_log2) { + clear(scores_scale); + }; + + __forceinline__ __device__ TensorT get_lse() const { return row_sum; } + + template + __forceinline__ __device__ TensorT update(Tensor0& acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); + + static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD); + if constexpr (init) { + reduce_max(scores, row_max); + if constexpr (WITH_SCALE) { + scale_apply_exp2(scores, row_max, sm_scale_log2); + } else { + apply_exp2(scores, row_max); + } + reduce_sum(scores, row_sum); + } else { + // update row_max + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + reduce_max(scores, row_max); + // update scores_scale and scale row_sum +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = row_max(mi); + if constexpr (WITH_SCALE) { + scores_scale(mi) = __expf((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2); + } else { + scores_scale(mi) = __expf(scores_max_prev(mi) - scores_max_cur); + } + row_sum(mi) *= 
scores_scale(mi); + } + // perform exp2 on scores + if constexpr (WITH_SCALE) { + scale_apply_exp2(scores, row_max, sm_scale_log2); + } else { + apply_exp2(scores, row_max); + } + // update row_sum + reduce_sum(scores, row_sum); + return scores_scale; + } + }; + + template + __forceinline__ __device__ TensorT finalize(Tensor0& acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD); + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float sum = row_sum(mi); + float inv_sum = 1.f / sum; + scores_scale(mi) = inv_sum; + row_max(mi) *= sm_scale_log2; + } + return scores_scale; + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1& acc_o) { + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale(mi); + } + } + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1& acc_o, Tensor2& scores_scale_input) { + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale_input(mi); + } + } + }; +}; + +} // namespace mla_attn diff --git a/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu new file mode 100644 index 000000000..40a9f6e56 --- /dev/null +++ b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu @@ -0,0 +1,235 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
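The OnlineSoftmax updater above implements the usual streaming-softmax recurrence: keep a running row maximum and row sum, rescale the partially accumulated output whenever the maximum grows (update() together with rescale_o()), and fold the reciprocal of the final sum into the scale in finalize(). The following scalar, host-side sketch shows the same recurrence for one row; it is for intuition only, uses plain base-e exp instead of the kernel's scaled exponential path, and none of its names appear in this diff.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct RunningSoftmaxRow {
  float m = -5e4f;  // same fill_value the updater uses for masked entries
  float d = 0.f;    // running denominator (row_sum)
  float o = 0.f;    // stands in for the accumulated P*V output of the row

  // Consume one tile of attention scores and their associated values.
  void update(const std::vector<float>& s, const std::vector<float>& v) {
    float m_new = m;
    for (float x : s) m_new = std::max(m_new, x);
    const float scale = std::exp(m - m_new);  // scores_scale in the kernel
    d *= scale;
    o *= scale;                               // what rescale_o() does to acc_o
    for (size_t i = 0; i < s.size(); ++i) {
      const float p = std::exp(s[i] - m_new);
      d += p;
      o += p * v[i];
    }
    m = m_new;
  }

  float finalize() const { return o / d; }    // normalize by the row sum
};

int main() {
  RunningSoftmaxRow row;
  row.update({0.1f, 1.2f}, {1.f, 2.f});   // first KV tile
  row.update({2.5f, -0.3f}, {3.f, 4.f});  // second KV tile; triggers a rescale
  std::printf("%f\n", row.finalize());
  return 0;
}
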
+ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "cute/tensor.hpp" +#include "mla_hopper.cuh" +#include +#include +#include + +#include "batch_mla_with_paged_kv_cache.h" +#include "env.h" + +using namespace cute; +using namespace mla_attn; +using namespace std; + +template +struct cascade_type_traits { + using type = T; + using cutlass_type = T; +}; +template <> +struct cascade_type_traits { + using type = __nv_bfloat16; + using cutlass_type = cutlass::bfloat16_t;; +}; +template <> +struct cascade_type_traits { + using type = half; + using cutlass_type = cutlass::half_t; +}; +template <> +struct cascade_type_traits { + using type = __nv_fp8_e4m3; + using cutlass_type = cutlass::float_e4m3_t; +}; + +template +void BatchMLAWithPagedKVCacheKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& q, // [token_num, q_head_num, head_dim] + const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim] + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const paddle::Tensor& num_blocks_x_device, + const std::string& cache_quant_type_str, + const int num_blocks_x, + const int max_seq_len, + const int max_dec_len, + const float softmax_scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int draft_token_num, + const bool causal, + cudaStream_t& stream, + paddle::Tensor* out) { + using NV_TYPE = typename cascade_type_traits::type; + using CUTLASS_TYPE = typename cascade_type_traits::cutlass_type; + const auto token_num = meta_data.token_nums; + const auto block_size = meta_data.block_size; + const auto bsz = meta_data.batch_size; + const auto q_head_num = meta_data.q_num_heads; + const auto max_block_num_per_seq = meta_data.max_blocks_per_seq; + const auto max_block_num = bsz * max_block_num_per_seq; + const uint32_t chunk_size = get_max_partition_size(bsz); + + + int q_head_dim = meta_data.head_dims; + int k_head_dim = meta_data.head_dims; + int v_head_dim = meta_data.head_dims_v; + // int num_chunks = max_dec_len / chunk_size; + int num_chunks = div_up(max_dec_len, chunk_size); + + auto *allocator = paddle::GetAllocator(q.place()); + phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp; + O_tmp = allocator->Allocate( + phi::SizeOf(q.dtype()) * + static_cast(num_chunks * bsz * draft_token_num * q_head_num * v_head_dim)); + m_tmp = allocator->Allocate( + sizeof(float) * + static_cast(num_chunks * bsz * draft_token_num * q_head_num)); + d_tmp = allocator->Allocate( + sizeof(float) * + static_cast(num_chunks * bsz * draft_token_num * q_head_num)); + + Params params = {}; + params.Q = reinterpret_cast(const_cast(q.data())); + params.KV = reinterpret_cast(const_cast(latent_cache.data())); + params.O = 
reinterpret_cast(const_cast(out->data())); + params.O_tmp = reinterpret_cast(O_tmp->ptr()); + params.m = reinterpret_cast(m_tmp->ptr()); + params.d = reinterpret_cast(d_tmp->ptr()); + params.block_tables = const_cast(block_tables.data()); + params.seq_lens_this_time = const_cast(seq_lens_this_time.data()); + params.seq_lens_encoder = const_cast(seq_lens_encoder.data()); + params.seq_lens_decoder = const_cast(seq_lens_decoder.data()); + params.cumsum_q_seqlens = const_cast(cu_seqlens_q.data()); + params.padding_offsets = const_cast(padding_offsets.data()); + params.batch_ids = const_cast(batch_ids.data()); + params.tile_ids_per_batch = const_cast(tile_ids_per_batch.data()); + params.num_blocks_x = const_cast(num_blocks_x_device.data()); + params.num_blocks_x_int = num_blocks_x; + params.q_stride_bsz = q_head_num * q_head_dim; + params.q_stride_head_num = q_head_dim; + params.kv_stride_block_num = block_size * k_head_dim; + params.kv_stride_block_size = k_head_dim; + params.o_stride_bsz = q_head_num * v_head_dim; + params.o_stride_head_num = v_head_dim; + params.bsz = bsz; + params.token_num = token_num; + params.max_seq_len = max_seq_len; + params.max_block_num = max_block_num; + params.max_block_num_per_seq = max_block_num_per_seq; + params.q_num_head = q_head_num; + params.qk_head_dim = q_head_dim; + params.vo_head_dim = v_head_dim; + params.block_size = block_size; + params.max_draft_token_num = draft_token_num; + params.sm_scale = softmax_scale; + params.chunk_size = chunk_size; + params.chunk_num = num_chunks; + + if (q_head_dim == 576) { + BatchMLAWithPagedKVCacheDispatched<576, 512, NV_TYPE>( + params, stream + ); + } else { + PD_THROW("error!!! q_head_dim must be 576 !!!\n"); + } +} + +template void BatchMLAWithPagedKVCacheKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& q, // [token_num, q_head_num, head_dim] + const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim] + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const paddle::Tensor& num_blocks_x_device, + const std::string& cache_quant_type_str, + const int num_blocks_x, + const int max_seq_len, + const int max_dec_len, + const float softmax_scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int draft_token_num, + const bool causal, + cudaStream_t& stream, + paddle::Tensor* out); + + +template void BatchMLAWithPagedKVCacheKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& q, // [token_num, q_head_num, head_dim] + const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim] + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_scale, // [num_kv_heads, 
head_dim] + const paddle::optional& cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const paddle::Tensor& num_blocks_x_device, + const std::string& cache_quant_type_str, + const int num_blocks_x, + const int max_seq_len, + const int max_dec_len, + const float softmax_scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int draft_token_num, + const bool causal, + cudaStream_t& stream, + paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h new file mode 100644 index 000000000..128c171ea --- /dev/null +++ b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h @@ -0,0 +1,69 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include "paddle/extension.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/allocator.h" +#include "append_attn/utils.cuh" + +template +void BatchMLAWithPagedKVCacheKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& q, // [token_num, q_head_num, head_dim] + const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim] + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const paddle::Tensor& num_blocks_x_device, + const std::string& cache_quant_type_str, + const int num_blocks_x, + const int max_seq_len, + const int max_dec_len, + const float softmax_scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int draft_token_num, + const bool causal, + cudaStream_t& stream, + paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/mla_attn/epilogue.cuh b/custom_ops/gpu_ops/mla_attn/epilogue.cuh new file mode 100644 index 000000000..72d1b5570 --- /dev/null +++ b/custom_ops/gpu_ops/mla_attn/epilogue.cuh @@ -0,0 +1,175 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. 
+ */ + + +#ifndef ATTENTION_HOPPER_EPILOGUE_CUH_ +#define ATTENTION_HOPPER_EPILOGUE_CUH_ + +#include + +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "named_barrier.cuh" +#include "utils.cuh" + +#ifdef DEBUG_MLA +#undef DEBUG_MLA +#endif +// #define DEBUG_MLA + +namespace mla_attn { + +using namespace cute; + +template +struct CollectiveEpilogue { + using DTypeO = typename Ktraits::DTypeO; + static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q; + static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV; + static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO; + using TileShape_PDV = Shape, Int, Int>; + + static constexpr int NUM_WARPS = Ktraits::NUM_WARPS; + static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp; + + static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})), + decltype(cute::get<1>(TileShape_PDV{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{}))); + + using SmemCopyAtomO = Copy_Atom; + using SharedStorage = cute::array_aligned>; + + using ShapeT = cute::Shape; + using StrideT = cute::Shape; + using LayoutT = cute::Layout; + + using ShapeTmpT = cute::Shape; + using StrideTmpT = cute::Shape; + using LayoutTmpT = cute::Layout; + + using ShapeNTMAT = cute::Shape; + using StrideNTMAT = cute::Shape; + using LayoutNTMAT = cute::Layout; + + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; + using TMA_O = decltype(make_tma_copy( + GmemTiledCopyOTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{}, + select<0, 1>(TileShape_PDV{}), _1{})); // no mcast for O + + static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v); // 8 + static_assert(HEAD_DIM_VO % VEC_SIZE == 0); + static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM_VO / VEC_SIZE; // 64 + static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0); + static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW; + using TiledCopyOAtom = cute::Copy_Atom, DTypeO>; + using TiledCopyOThrLayout = decltype(cute::make_layout( + cute::make_shape(Int{}, Int{}), LayoutRight{})); + using TiledCopyOValLayout = + decltype(cute::make_layout(cute::make_shape(_1{}, Int{}), LayoutRight{})); + using TiledCopyO = + decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{}, // Thr layout + TiledCopyOValLayout{} // Val layout + )); + struct Arguments { + DTypeO* O_ptr; + LayoutNTMAT const layout_O; + DTypeO* O_ptr_tmp; + LayoutNTMAT const layout_O_tmp; + }; + + // Device side kernel params + struct Params { + DTypeO* O_ptr; + LayoutNTMAT const layout_O; + DTypeO* O_ptr_tmp; + LayoutNTMAT const layout_O_tmp; + }; + + static Params to_underlying_arguments_ntma(Arguments const& args) { + return {args.O_ptr, args.layout_O, args.O_ptr_tmp, args.layout_O_tmp}; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& epilogue_params) {} + + template + CUTLASS_DEVICE void store(Params const& epilogue_params, + FrgTensorO const& tOrO, + FrgTensorLSE const& lse, + SharedStorage& shared_storage, + TiledMma tiled_mma, + const int thread_idx, + const int bid, + const int bsz, + const int seq_len_now, + const int start_token_idx, + const int tile_idx, + const int kv_len, + const int chunk_size, + const int 
max_draft_token_num, + const int o_stride_bsz) { + const int num_chunks = cute::ceil_div(kv_len, chunk_size); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + + Tensor tOrO_out = convert_type(tOrO); + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // make sure gemm done + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS, + /*id=*/static_cast(NamedBarriers::kValueEmpty)); + // r2s + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + // make sure r2s done + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS, + /*id=*/static_cast(NamedBarriers::kValueEmpty)); + TiledCopyO gmem_tiled_copy_O; + auto O_ptr = num_chunks == 1 ? epilogue_params.O_ptr + start_token_idx * o_stride_bsz : epilogue_params.O_ptr_tmp + (tile_idx * bsz + bid) * max_draft_token_num * o_stride_bsz; + Tensor mO = make_tensor(make_gmem_ptr(O_ptr), epilogue_params.layout_O); + Tensor gO = local_tile(mO, select<0, 1>(TileShape_PDV{}), make_coord(_, _0{}))(_, _, _0{}); + Tensor cO = make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx) + ThrCopy thr_copy_O = gmem_tiled_copy_O.get_slice(thread_idx); + Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D) + Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY, CPY_O, CPY_D) + Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D) + Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D)) + Tensor tOsOGroup = flatten_1(tOsO); // (CPY, (CPY_O, CPY_D)) + Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D)) + + // copy if not out of bound + auto predicate_fn = [&](auto coords) { + auto s_coords = tOcOGroup(_0{}, coords); + return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, seq_len_now); + }; + copy_if(gmem_tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup); + } +}; + +} // namespace mla_attn + +#endif // ATTENTION_HOPPER_EPILOGUE_CUH_ diff --git a/custom_ops/gpu_ops/mla_attn/kernel_traits.cuh b/custom_ops/gpu_ops/mla_attn/kernel_traits.cuh new file mode 100644 index 000000000..116ccb7c8 --- /dev/null +++ b/custom_ops/gpu_ops/mla_attn/kernel_traits.cuh @@ -0,0 +1,163 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. 
+ */ + +#ifndef ATTENTION_HOPPER_KERNEL_TRAITS_CUH_ +#define ATTENTION_HOPPER_KERNEL_TRAITS_CUH_ + +#include + +#include "cute/algorithm/copy.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" + +namespace mla_attn { + +using namespace cute; + +template +struct alignas(16) SharedStorageQKVO { + alignas(16) cute::array_aligned> smem_q; + alignas(16) cute::array_aligned> smem_p; + alignas(16) cute::array_aligned> smem_scale; + union { + alignas(16) cute::array_aligned> smem_kv; + alignas(16) cute::array_aligned> smem_o; + }; + struct { + alignas(16) typename MainloopPipelineQ::SharedStorage pipeline_q; + alignas(16) typename MainloopPipeline::SharedStorage pipeline_kv; + }; +}; + +template +struct AttentionKernelTraits { + + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + using DTypeQKAccum = float; + using DTypePVAccum = float; + using NV_TYPE = NV_TYPE_; + + + static constexpr bool USE_TMA_LOAD_KV = USE_TMA_LOAD_KV_; + static constexpr int GROUP_SIZE = GROUP_SIZE_; + static constexpr int BLOCK_SHAPE_Q = BLOCK_SHAPE_Q_; + static_assert(BLOCK_SHAPE_Q % 64 == 0); + static constexpr int BLOCK_SHAPE_KV = BLOCK_SHAPE_KV_; + static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_; + static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_; + static constexpr int NUM_PER_STAGE = BLOCK_SHAPE_KV * HEAD_DIM_QK; + static_assert(HEAD_DIM_QK % 32 == 0); + static_assert(HEAD_DIM_VO % 32 == 0); + + static constexpr int NUM_WARPS = 12; + static constexpr int NUM_THREADS = 384; + static constexpr int NUM_PRODUCER_THREADS = 128; + + using TileShape_QKD = Shape, Int, Int>; + using TileShape_PDV = Shape, Int, Int>; + + static constexpr int NUM_STAGES = NUM_STAGES_; + + using AtomLayoutQKD = Layout, _1, _1>>; + using AtomLayoutPV = Layout, _2, _1>>; + using TiledMmaQK = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), AtomLayoutQKD{})); + using TiledMmaPV = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutPV{})); + using TiledMmaPVSS = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutPV{})); + + static constexpr int NUM_MMA_THREADS = size(TiledMmaPV{}); + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{}))); + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})), + decltype(cute::get<2>(TileShape_QKD{}))>()); + using SmemLayoutK = decltype(tile_to_shape( + SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int{}))); + using SmemLayoutVt = decltype(composition( + SmemLayoutK{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}), + get<1>(TileShape_QKD{}), Int{}), + Step<_2, _1, _3>{}))); + using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})), + decltype(cute::get<1>(TileShape_PDV{}))>()); + using SmemLayoutV = decltype(tile_to_shape( + SmemLayoutAtomV{}, + make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int<1>{}))); + + // Note 
this is the transpose in terms of the view, not in terms of memory. + using SmemLayoutVtOneStage = decltype(composition( + SmemLayoutV{}, make_ordered_layout(make_shape(get<1>(TileShape_PDV{}), + get<2>(TileShape_PDV{}), Int<1>{}), + Step<_2, _1, _3>{}))); + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})), + decltype(cute::get<1>(TileShape_PDV{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{}))); + + using SmemCopyAtom = Copy_Atom; + + static constexpr bool IS_CTA_32 = (BLOCK_SHAPE_KV == 32); + using SmemLayoutRowOneStage = Layout>, Stride<_1, _2>>; + using SmemLayoutRowTwoStage = Layout, _2>, Stride<_1, _2, _256>>; + using SmemLayoutRow = std::conditional_t; + + using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})), + decltype(cute::get<1>(TileShape_QKD{}))>()); + using SmemLayoutPSSOneStage = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_QKD{}))); + using SmemLayoutPSSTwoStage = decltype(tile_to_shape(SmemLayoutAtomP{}, make_shape(Int{}, Int{}, Int<2>{}))); + using SmemLayoutP = std::conditional_t; + + using MainloopPipelineQ = typename cutlass::PipelineAsync<1>; + using PipelineStateQ = typename cutlass::PipelineState<1>; + using MainloopPipeline = + std::conditional_t, + typename cutlass::PipelineAsync>; + using PipelineState = typename cutlass::PipelineState; + + using SharedStorage = SharedStorageQKVO; +}; + +} // namespace mla_attn + +#endif diff --git a/custom_ops/gpu_ops/mla_attn/mainloop_load.cuh b/custom_ops/gpu_ops/mla_attn/mainloop_load.cuh new file mode 100644 index 000000000..9c67f601f --- /dev/null +++ b/custom_ops/gpu_ops/mla_attn/mainloop_load.cuh @@ -0,0 +1,348 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
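kernel_traits.cuh above fixes the CTA at NUM_WARPS = 12, i.e. NUM_THREADS = 384 arranged as three 128-thread warp groups, with NUM_PRODUCER_THREADS = 128 set aside for the load side. The small device-side sketch below shows how a thread can derive its warp-group role; treating warp group 0 as the producer and warp groups 1 and 2 as the MMA consumers is an assumption inferred from the mma code later in this diff, and role_demo itself is illustrative rather than part of the patch.

#include <cuda_runtime.h>

// For a flat 384-thread block this effectively reproduces what
// cutlass::canonical_warp_group_idx() returns (128 threads per warp group).
__device__ __forceinline__ int warp_group_idx_sketch() {
  return threadIdx.x / 128;
}

__global__ void role_demo() {
  switch (warp_group_idx_sketch()) {
    case 0: /* assumed producer warp group: issues Q and paged-KV copies */ break;
    case 1: /* consumer warp group 1: QK GEMM + online softmax */ break;
    default: /* consumer warp group 2: PV GEMM */ break;
  }
}
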
+ +#ifndef ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_ +#define ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_ + +#include +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "named_barrier.cuh" +#include "utils.cuh" + +#ifdef DEBUG_MLA +#undef DEBUG_MLA +#endif +// #define DEBUG_MLA + +namespace mla_attn { + +using namespace cute; + +template +struct CollectiveMainloop { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using DTypeMD = float; + using IdType = typename Ktraits::IdType; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + using TileShape_PDV = typename Ktraits::TileShape_PDV; + static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{}); + static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{}); + + static constexpr int NUM_STAGES = Ktraits::NUM_STAGES; + static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK; + static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO; + + using GmemTiledCopyKV = cute::SM90_TMA_LOAD; + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(DTypeQ); // 8 + static_assert(HEAD_DIM_QK % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // 576 512 + static constexpr int kGmemThreadsPerRow = 64 / kGmemElemsPerLoad; // 8 + using AlignmentTypeQ = cute::uint_byte_t(sizeof(DTypeQ)) * kGmemElemsPerLoad>; + using GmemCopyAtomQ = cute::Copy_Atom, DTypeQ>; + static constexpr int kNThreadsLoad = Ktraits::NUM_PRODUCER_THREADS; + static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout< + Shape, Int>, // 32, 8 + Stride, _1>>; + using GmemTiledCopy = decltype(make_tiled_copy( + GmemCopyAtomQ{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + + using GmemLayoutAtomQ = Layout< + Shape, Int>, // 32, 8 + Stride, _1>>; + using GmemTiledCopyQ = decltype(make_tiled_copy( + GmemCopyAtomQ{}, + GmemLayoutAtomQ{}, + Layout>{})); // Val layout, 8 vals per read + + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutAtomQ = typename Ktraits::SmemLayoutAtomQ; + + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + + using ShapeQT = cute::Shape; + using StrideQT = cute::Shape; + using LayoutQT = cute::Layout; + + using ShapeT = cute::Shape; + using StrideT = cute::Shape; + using LayoutT = cute::Layout; + + using ShapeMDT = cute::Shape; + using StrideMDT = cute::Shape; + using LayoutMDT = cute::Layout; + + using TMA_KV = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), StrideT{} + ), + take<0, 2>(SmemLayoutK{}), + select<1, 2>(TileShape_QKD{}), + _1{})); // no mcast for KV + + static constexpr bool USE_TMA_LOAD_KV = Ktraits::USE_TMA_LOAD_KV; + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + using MainloopPipelineQ = typename Ktraits::MainloopPipelineQ; + using PipelineParamsQ = typename MainloopPipelineQ::Params; + using PipelineStateQ = typename MainloopPipelineQ::PipelineState; + + static constexpr uint32_t TmaTransactionBytesQ = + static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t 
TmaTransactionBytesKV = + static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + + // Host side kernel arguments + struct Arguments { + LayoutQT layout_Q; + LayoutT layout_KV; + LayoutMDT layout_MD; + DTypeQ const* Q_ptr; + DTypeKV const* KV_ptr; + DTypeMD const* m_ptr; + DTypeMD const* d_ptr; + IdType const* kv_block_tables; + IdType const* seq_lens_this_time; + IdType const* seq_lens_encoder; + IdType const* seq_lens_decoder; + IdType const* cumsum_q_seqlens; + IdType const* batch_ids; + IdType const* tile_ids_per_batch; + IdType const* num_blocks_x; + float sm_scale; + int bsz; + int max_block_num; + int max_block_num_per_seq; + int q_stride_bsz; + int q_stride_head_num; + int kv_stride_block_num; + int kv_stride_block_size; + int o_stride_bsz; + int o_stride_head_num; + int chunk_size; + int chunk_num; + int max_draft_token_num; + }; + + // Device side kernel params + struct Params { + LayoutQT layout_Q; + LayoutT layout_KV; + LayoutMDT layout_MD; + DTypeQ *Q_ptr; + DTypeKV* KV_ptr; + DTypeMD* m_ptr; + DTypeMD* d_ptr; + IdType* kv_block_tables; + IdType* seq_lens_this_time; + IdType* seq_lens_encoder; + IdType* seq_lens_decoder; + IdType* cumsum_q_seqlens; + IdType* batch_ids; + IdType* tile_ids_per_batch; + IdType* num_blocks_x; + float sm_scale; + int bsz; + int max_block_num; + int max_block_num_per_seq; + int q_stride_bsz; + int q_stride_head_num; + int kv_stride_block_num; + int kv_stride_block_size; + int o_stride_bsz; + int o_stride_head_num; + int chunk_size; + int chunk_num; + int max_draft_token_num; + TMA_KV tma_load_KV; + }; + + static Params to_underlying_arguments(Arguments const& args) { + TMA_KV tma_load_KV; + if constexpr (USE_TMA_LOAD_KV) { + Tensor mKV = make_tensor(make_gmem_ptr(args.KV_ptr), args.layout_KV); + tma_load_KV = + make_tma_copy(GmemTiledCopyKV{}, mKV, SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_QKD{}), _1{}); + } + return {args.layout_Q, + args.layout_KV, + args.layout_MD, + const_cast(args.Q_ptr), + const_cast(args.KV_ptr), + const_cast(args.m_ptr), + const_cast(args.d_ptr), + const_cast(args.kv_block_tables), + const_cast(args.seq_lens_this_time), + const_cast(args.seq_lens_encoder), + const_cast(args.seq_lens_decoder), + const_cast(args.cumsum_q_seqlens), + const_cast(args.batch_ids), + const_cast(args.tile_ids_per_batch), + const_cast(args.num_blocks_x), + args.sm_scale, + args.bsz, + args.max_block_num, + args.max_block_num_per_seq, + args.q_stride_bsz, + args.q_stride_head_num, + args.kv_stride_block_num, + args.kv_stride_block_size, + args.o_stride_bsz, + args.o_stride_head_num, + args.chunk_size, + args.chunk_num, + args.max_draft_token_num, + tma_load_KV + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + if constexpr (USE_TMA_LOAD_KV) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_KV.get_tma_descriptor()); + } + } + + template + CUTLASS_DEVICE void load_q(Params const& mainloop_params, + MainloopPipelineQ pipeline_q, + PipelineStateQ& smem_pipe_write_q, + SharedStorage& shared_storage, + const int thread_idx, + const int bid) { + int start_q_token_idx = mainloop_params.cumsum_q_seqlens[bid]; + int offset_Q = mainloop_params.q_stride_bsz * start_q_token_idx; + Tensor mQ = make_tensor(make_gmem_ptr(mainloop_params.Q_ptr + offset_Q), mainloop_params.layout_Q); + Tensor gQ = + local_tile(mQ, select<0, 2>(TileShape_QKD{}), make_coord(_, _0{}))(_, _, _0{}); + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor cQ = 
cute::make_identity_tensor(gQ.shape()); + + GmemTiledCopyQ gmem_tiled_copy_q; + auto gmem_thr_copy_q = gmem_tiled_copy_q.get_slice(thread_idx); + Tensor tQgQ = gmem_thr_copy_q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_q.partition_D(sQ); + Tensor tQcQ = gmem_thr_copy_q.partition_D(cQ); + Tensor tQcQGroup = flatten_1(tQcQ); + + int valid_q_size = mainloop_params.seq_lens_this_time[bid]; + auto q_predicate_fn = [&](auto coords) { + auto s_coords = tQcQGroup(_0{}, coords); + return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, valid_q_size); + }; + Tensor tQgQiGroup = flatten_1(tQgQ); + Tensor tQsQiGroup = flatten_1(tQsQ); + + pipeline_q.producer_acquire(smem_pipe_write_q); + copy_if(gmem_tiled_copy_q, q_predicate_fn, tQgQiGroup, tQsQiGroup); + pipeline_q.producer_commit(smem_pipe_write_q, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_q; + } + + template + CUTLASS_DEVICE void load_kv(Params const& mainloop_params, + MainloopPipeline pipeline_kv, + PipelineState& smem_pipe_write_kv, + SharedStorage& shared_storage, + const int bid, + const int kv_len, + const int tile_idx) { + int thread_idx = threadIdx.x; + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{}); + Tensor mKV = make_tensor(make_gmem_ptr(mainloop_params.KV_ptr), mainloop_params.layout_KV); + Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _); + GmemTiledCopy gmem_tiled_copy_kv; + auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx); + + static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{}); + const int start_len = tile_idx * mainloop_params.chunk_size; + const int start_tile_idx = start_len / BLOCK_SHAPE_KV; + const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1; + + auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1))); + + Tensor tKgK = gmem_thr_copy_kv.partition_S(gKV); + Tensor tKsK = gmem_thr_copy_kv.partition_S(sK); + + for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) { + const int block_idx = kv_block_tables(bid, kv_tile_idx); + pipeline_kv.producer_acquire(smem_pipe_write_kv); + Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, block_idx)); + Tensor tKsKiGroup = + flatten_1(tKsK(_, _, _, smem_pipe_write_kv.index())); + copy(gmem_tiled_copy_kv, tKgKiGroup, tKsKiGroup); + pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_kv; + } + } + + template + CUTLASS_DEVICE void load_kv_tma(Params const& mainloop_params, + MainloopPipeline pipeline_kv, + PipelineState& smem_pipe_write_kv, + SharedStorage& shared_storage, + const int bid, + const int kv_len, + const int tile_idx) { + int thread_idx = threadIdx.x; + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{}); + + Tensor mKV = mainloop_params.tma_load_KV.get_tma_tensor(mainloop_params.layout_KV.shape()); + + // Prepare the TMA loads + Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _); + auto [tKgK, tKsK] = + tma_partition(mainloop_params.tma_load_KV, _0{}, Layout<_1>{}, + group_modes<0, 2>(sK), group_modes<0, 2>(gKV)); + + static 
constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{}); + const int start_len = tile_idx * mainloop_params.chunk_size; + const int start_tile_idx = start_len / BLOCK_SHAPE_KV; + const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1; + + auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1))); + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { +#pragma unroll 2 + for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) { + const int block_idx = kv_block_tables(bid, kv_tile_idx); + pipeline_kv.producer_acquire(smem_pipe_write_kv); + copy(mainloop_params.tma_load_KV.with(*pipeline_kv.producer_get_barrier(smem_pipe_write_kv), /*mcast_mask=*/0), + tKgK(_, block_idx), tKsK(_, smem_pipe_write_kv.index())); + ++smem_pipe_write_kv; + } + } + } +}; + +} // namespace mla_attn + +#endif // ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_ diff --git a/custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh b/custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh new file mode 100644 index 000000000..77d059583 --- /dev/null +++ b/custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh @@ -0,0 +1,500 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
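The load_q / load_kv / load_kv_tma producers above walk one chunk of the paged KV cache at a time: chunk tile_idx of a sequence covers KV positions [tile_idx * chunk_size, min(tile_idx * chunk_size + chunk_size, kv_len)), and every BLOCK_SHAPE_KV-sized tile inside that window is resolved to a physical cache block through kv_block_tables before being copied to shared memory, iterating from the last tile down to the first. A host-side sketch of that index arithmetic follows; the concrete sizes and the block-table contents are made-up example values, not values taken from this diff.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int BLOCK_SHAPE_KV = 64;   // KV tile size (example value)
  const int chunk_size = 256;      // partition size per chunk (example value)
  const int kv_len = 700;          // decoded KV length of this sequence
  const int tile_idx = 2;          // third chunk assigned to this CTA
  const std::vector<int> block_table = {7, 3, 9, 12, 5, 8, 1, 0, 4, 6, 2, 10};

  const int start_len = tile_idx * chunk_size;                        // 512
  const int start_tile_idx = start_len / BLOCK_SHAPE_KV;              // 8
  const int end_tile_idx =
      (std::min(start_len + chunk_size, kv_len) + BLOCK_SHAPE_KV - 1) / BLOCK_SHAPE_KV - 1;  // 10

  // Same reverse order as the producer loop in load_kv()/load_kv_tma().
  for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
    std::printf("kv tile %d -> physical cache block %d\n",
                kv_tile_idx, block_table[kv_tile_idx]);
  }
  return 0;
}
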
+ +#ifndef ATTENTION_HOPPER_MAINLOOP_MMA_CUH_ +#define ATTENTION_HOPPER_MAINLOOP_MMA_CUH_ + +#include +#include +#include +#include +#include "named_barrier.cuh" + +// #define DEBUG_MLA + +namespace mla_attn { + +template +CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, + MainloopPipelineQ pipeline_q, + PipelineStateQ& smem_pipe_read_q, + MainloopPipeline pipeline_kv, + PipelineState& smem_pipe_read_kv, + FrgTensorO& tOrO, + AttentionUpdater& attention_updater, + const int thread_idx, + const int bid, + const int kv_len, + const int qo_len, + const int tile_idx, + SharedStorage& shared_storage) { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using DTypeMD = typename Ktraits::DTypeO; + using DTypeQKAccum = typename Ktraits::DTypeQKAccum; + using IdType = typename Ktraits::IdType; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutP = typename Ktraits::SmemLayoutP; + using SmemLayoutRow = typename Ktraits::SmemLayoutRow; + using SmemCopyAtom = typename Ktraits::SmemCopyAtom; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage; + static_assert(is_rmem::value, "O tensor must be rmem resident."); + + const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size); + + static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{}); + static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{}); + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{}); + Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{}); + Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{}); + Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{}); + Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{}); + Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _); // (bsz * draft_token_num * num_head) + Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _); + + typename Ktraits::TiledMmaQK tiled_mma_qk; + auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx); + auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk); + auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); + Tensor tPsP = smem_thr_copy_P.partition_D(sPSS); + Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup); + + typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss; + auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx); + Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1); + Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2); + Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS); + + const int start_len = tile_idx * mainloop_params.chunk_size; + const int start_tile_idx = start_len / BLOCK_SHAPE_KV; + const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1; + int kv_tile_idx = end_tile_idx; + + auto consumer_wait = [](auto& 
pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if (warp_group_idx == 1) { + // consumer 0, compute qk + Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ); + Tensor tSrK = threadMmaQK.partition_fragment_B(sK); + + constexpr int n_masking_steps = !CAUSAL ? 1 : cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) + 1; + auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; }; + bool is_first_step = true; + // wait q + consumer_wait(pipeline_q, smem_pipe_read_q); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); +#pragma unroll 1 + for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) { + // wait kv + consumer_wait(pipeline_kv, smem_pipe_read_kv); + // gemm qk + gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()), + tSrS); + // mask + if (masking_step > 0) { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE; + int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV; + if constexpr (!CAUSAL) { // Just masking based on col + if (kv_idx >= kv_len) { + tSrS(i) = AttentionUpdater::fill_value; + } + } else { + if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) { + tSrS(i) = AttentionUpdater::fill_value; + } + } + } + } + + // update s (exp(s - m)) + Tensor scale_o = is_first_step ? attention_updater.update(tSrS) : attention_updater.update(tSrS); + is_first_step = false; + + Tensor convert_tSrS = convert_type(tSrS); + Tensor tPrP = smem_thr_copy_P.retile_S(convert_tSrS); + + // gather qk gemm res + cute::copy(smem_tiled_copy_P, tPrP, tPsP); + cute::copy(scale_o, tScalesScale); + // r2s fence wgmma + cutlass::arch::fence_view_async_shared(); + // make sure r2s all done + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWarpSchedulerWG1)); + + attention_updater.rescale_o(tOrO, scale_o); + + // pv gemm + if (smem_pipe_read_kv.index() == 0) { + gemm(tiled_mma_pv_ss, tOrP_CS2, + tOrV1(_, _, _, _0{}), tOrO); + } else { + gemm(tiled_mma_pv_ss, tOrP_CS2, + tOrV2(_, _, _, _0{}), tOrO); + } + + pipeline_kv.consumer_release(smem_pipe_read_kv); + ++smem_pipe_read_kv; + // sync WG1 WG2 + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWG1WG2Sync)); + } + // release q + pipeline_q.consumer_release(smem_pipe_read_q); + ++smem_pipe_read_q; + + // normalize + Tensor scale_o = attention_updater.finalize(tSrS); // warp reduce row sum + if (chunk_num_this_seq == 1) { + // norm + cute::copy(scale_o, tScalesScale); + + cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWarpSchedulerWG2)); + attention_updater.rescale_o(tOrO, scale_o); + } + + // WG1 write m,d back to gmem + if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, eg. 
t0->row0 row8,t4->row1 row9 + const int warp_idx = thread_idx / 32; +#pragma unroll + for (int w_i = 0; w_i < 2; ++w_i) { + const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i; + const int token_idx = token_group_idx / Ktraits::GROUP_SIZE; + + if (token_idx < qo_len) { + const int head_idx = token_group_idx % Ktraits::GROUP_SIZE; + const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE; + const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx; + mM(write_idx) = static_cast(attention_updater.row_max(w_i)); + mD(write_idx) = static_cast(attention_updater.row_sum(w_i)); + } + } + } + } else if (warp_group_idx == 2) { + // consumer 1, compute pv + Tensor scale_o = make_tensor(Shape<_2>{}); + for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) { + // wait kv + consumer_wait(pipeline_kv, smem_pipe_read_kv); + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWarpSchedulerWG1)); + + // A: tPsP + cute::copy(tScalesScale, scale_o); + + // rescale + attention_updater.rescale_o(tOrO, scale_o); + if (smem_pipe_read_kv.index() == 0) { + gemm(tiled_mma_pv_ss, tOrP_CS2, + tOrV1(_, _, _, _0{}), tOrO); + } else { + gemm(tiled_mma_pv_ss, tOrP_CS2, + tOrV2(_, _, _, _0{}), tOrO); + } + + pipeline_kv.consumer_release(smem_pipe_read_kv); + ++smem_pipe_read_kv; + // sync WG1 WG2 + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWG1WG2Sync)); + } + if (chunk_num_this_seq == 1) { + // norm + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWarpSchedulerWG2)); + cute::copy(tScalesScale, scale_o); + attention_updater.rescale_o(tOrO, scale_o); + } + } + return; +} + +template +CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params, + MainloopPipelineQ pipeline_q, + PipelineStateQ& smem_pipe_read_q, + MainloopPipeline pipeline_kv, + PipelineState& smem_pipe_read_kv, + FrgTensorO& tOrO, + AttentionUpdater& attention_updater, + const int thread_idx, + const int bid, + const int kv_len, + const int qo_len, + const int tile_idx, + SharedStorage& shared_storage) { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using DTypeMD = typename Ktraits::DTypeO; // !!! 
bf16 + using DTypeQKAccum = typename Ktraits::DTypeQKAccum; + using IdType = typename Ktraits::IdType; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutP = typename Ktraits::SmemLayoutP; + using SmemLayoutRow = typename Ktraits::SmemLayoutRow; + using SmemCopyAtom = typename Ktraits::SmemCopyAtom; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage; + static_assert(is_rmem::value, "O tensor must be rmem resident."); + + const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size); + + static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{}); + static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{}); + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{}); + Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{}); + Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{}); + Tensor sVt_s3 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 2 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{}); + Tensor sVt_s4 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 3 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{}); + Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{}); + Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _); + Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _); + + Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{}); + + typename Ktraits::TiledMmaQK tiled_mma_qk; + auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx); + auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk); + auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); + Tensor tPsP = smem_thr_copy_P.partition_D(sPSS); + Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup, _); + + typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss; + auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx); + Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1); + Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2); + Tensor tOrV3 = threadMmaPVSS.partition_fragment_B(sVt_s3); + Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4); + Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS); + + const int start_len = tile_idx * mainloop_params.chunk_size; + const int start_tile_idx = start_len / BLOCK_SHAPE_KV; + const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1; + int kv_tile_idx = end_tile_idx; + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if (warp_group_idx == 1) { + // consumer 0, compute qk + Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ); + Tensor tSrK = threadMmaQK.partition_fragment_B(sK); + auto 
col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; }; + // wait q + consumer_wait(pipeline_q, smem_pipe_read_q); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + // wait k + consumer_wait(pipeline_kv, smem_pipe_read_kv); + // first qk gemm + gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()), + tSrS); + // mask + { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE; + int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV; + if constexpr (!CAUSAL) { // Just masking based on col + if (kv_idx >= kv_len) { + tSrS(i) = AttentionUpdater::fill_value; + } + } else { + if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) { + tSrS(i) = AttentionUpdater::fill_value; + } + } + } + } + + Tensor scale_o = attention_updater.update(tSrS); + Tensor tPrP = smem_thr_copy_P.retile_S(convert_type(tSrS)); + // gather qk gemm res + cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2)); + cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2)); + // r2s fence wgmma + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWarpSchedulerWG1)); + + constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) : 0; + --kv_tile_idx; + for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + PipelineState smem_pipe_read_kv_cur = smem_pipe_read_kv; + ++smem_pipe_read_kv; + // wait next kv + consumer_wait(pipeline_kv, smem_pipe_read_kv); + + // gemm next qk + gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()), + tSrS); + attention_updater.rescale_o(tOrO); + // last pv gemm + if (smem_pipe_read_kv_cur.index() == 0) { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2), + tOrV1(_, _, _, _0{}), tOrO); + } else if (smem_pipe_read_kv_cur.index() == 1) { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2), + tOrV2(_, _, _, _0{}), tOrO); + } else if (smem_pipe_read_kv_cur.index() == 2) { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2), + tOrV3(_, _, _, _0{}), tOrO); + } else { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2), + tOrV4(_, _, _, _0{}), tOrO); + } + // wait cur qk gemm + warpgroup_wait<1>(); + // mask p + if (masking_step > 0) { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE; + int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV; + if constexpr (!CAUSAL) { // Just masking based on col + if (kv_idx >= kv_len) { + tSrS(i) = AttentionUpdater::fill_value; + } + } else { + if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) { + tSrS(i) = AttentionUpdater::fill_value; + } + } + } + } + // update s (exp(s - m)) + Tensor scale_o = attention_updater.update(tSrS); + Tensor tPrP = smem_thr_copy_P.retile_S(convert_type(tSrS)); + + // gather qk gemm res + cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2)); + cute::copy(scale_o, tScalesScale(_, 
smem_pipe_read_kv.index() % 2)); + // r2s fence wgmma + cutlass::arch::fence_view_async_shared(); + // make sure tSrS r2s done + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWarpSchedulerWG1)); + // wait last pv gemm + warpgroup_wait<0>(); + // release last kv + pipeline_kv.consumer_release(smem_pipe_read_kv_cur); + } + // release q + pipeline_q.consumer_release(smem_pipe_read_q); + ++smem_pipe_read_q; + // compute last pv + attention_updater.rescale_o(tOrO); + if (smem_pipe_read_kv.index() == 0) { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2), + tOrV1(_, _, _, _0{}), tOrO); + } else if (smem_pipe_read_kv.index() == 1) { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2), + tOrV2(_, _, _, _0{}), tOrO); + } else if (smem_pipe_read_kv.index() == 2) { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2), + tOrV3(_, _, _, _0{}), tOrO); + } else { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2), + tOrV4(_, _, _, _0{}), tOrO); + } + scale_o = attention_updater.finalize(tSrS); + warpgroup_wait<0>(); + // release last kv + pipeline_kv.consumer_release(smem_pipe_read_kv); + ++smem_pipe_read_kv; + if (chunk_num_this_seq == 1) { + // norm + cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2)); + + cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWG1WG2LastSync)); + attention_updater.rescale_o(tOrO); + } + // WG1 write m,d back to gmem + if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, eg. t0->row0 row8,t4->row1 row9 + const int warp_idx = thread_idx / 32; +#pragma unroll + for (int w_i = 0; w_i < 2; ++w_i) { + const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i; + const int token_idx = token_group_idx / Ktraits::GROUP_SIZE; + + if (token_idx < qo_len) { + const int head_idx = token_group_idx % Ktraits::GROUP_SIZE; + const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE; + const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx; + mM(write_idx) = static_cast(attention_updater.row_max(w_i)); + mD(write_idx) = static_cast(attention_updater.row_sum(w_i)); + } + } + } + } else if (warp_group_idx == 2) { + // consumer 1, compute pv + Tensor scale_o = make_tensor(Shape<_2>{}); + for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) { + consumer_wait(pipeline_kv, smem_pipe_read_kv); + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWarpSchedulerWG1)); + // A: tPsP + cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o); + // rescale + attention_updater.rescale_o(tOrO, scale_o); + if (smem_pipe_read_kv.index() == 0) { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2), + tOrV1(_, _, _, _0{}), tOrO); + } else if (smem_pipe_read_kv.index() == 1) { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2), + tOrV2(_, _, _, _0{}), tOrO); + } else if (smem_pipe_read_kv.index() == 2) { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2), + tOrV3(_, _, _, _0{}), tOrO); + } else { + gemm(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2), + tOrV4(_, _, _, _0{}), tOrO); + } + pipeline_kv.consumer_release(smem_pipe_read_kv); + ++smem_pipe_read_kv; + } + if (chunk_num_this_seq == 1) { + // norm + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast(NamedBarriers::kWG1WG2LastSync)); + 
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o); + attention_updater.rescale_o(tOrO, scale_o); + } + } + return; +} + +} // namespace mla_attn + +#endif // ATTENTION_HOPPER_MAINLOOP_MMA_CUH_ diff --git a/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh b/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh new file mode 100644 index 000000000..c958264c1 --- /dev/null +++ b/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh @@ -0,0 +1,575 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. + */ + +#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_ +#define ATTENTION_HOPPER_PREFILL_SM90_CUH_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "attention_updater.cuh" +#include "cute/tensor.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "epilogue.cuh" +#include "helper.h" +#include "kernel_traits.cuh" +#include "mainloop_mma.cuh" +#include "mainloop_load.cuh" +#include "utils.cuh" + +#ifdef DEBUG_MLA +#undef DEBUG_MLA +#endif +// #define DEBUG_MLA + +namespace mla_attn { + +using namespace cute; + +template +struct Params { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + + alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head] + alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head] + alignas(16) DTypeO *O; // [token_num, head_num, dim_head] + alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head] + alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num] + alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num] + + alignas(16) IdType *block_tables; + alignas(16) IdType *seq_lens_this_time; + alignas(16) IdType *seq_lens_encoder; + alignas(16) IdType *seq_lens_decoder; + alignas(16) IdType *cumsum_q_seqlens; + alignas(16) IdType *padding_offsets; + + alignas(16) IdType *batch_ids; + alignas(16) IdType *tile_ids_per_batch; + alignas(16) IdType *num_blocks_x; + + + uint32_t q_stride_bsz; + uint32_t q_stride_head_num; + + uint32_t kv_stride_block_num; + uint32_t kv_stride_block_size; + + uint32_t o_stride_bsz; + uint32_t o_stride_head_num; + + int bsz; + int token_num; + int max_seq_len; + int max_block_num; + int max_block_num_per_seq; + int q_num_head; + int qk_head_dim; + int vo_head_dim; + int block_size; + int max_draft_token_num; + int chunk_size; + int chunk_num; + int num_blocks_x_int; + + float sm_scale; +}; + +#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) 
\ + if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else if (group_size == 64) { \ + constexpr size_t GROUP_SIZE = 64; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size: ", group_size); \ + return cudaErrorNotSupported; \ + } + +template +__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1) +MLAWithKVCacheKernel(CUTE_GRID_CONSTANT + typename CollectiveMainloop::Params const mainloop_params, + CUTE_GRID_CONSTANT + typename CollectiveEpilogue::Params const epilogue_params) { + + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using DTypeO = typename Ktraits::DTypeO; + using DTypeQKAccum = typename Ktraits::DTypeQKAccum; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + using TileShape_PDV = typename Ktraits::TileShape_PDV; + + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; + static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS; + static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q; + static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV; + const int num_blocks_x = mainloop_params.num_blocks_x[0]; + + static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV; + + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ; + using PipelineParamsQ = typename MainloopPipelineQ::Params; + using PipelineStateQ = typename MainloopPipelineQ::PipelineState; + + extern __shared__ char shared_memory[]; + auto& shared_storage = *reinterpret_cast(shared_memory); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); + CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + if constexpr (use_tma_load_kv) { + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NUM_MMA_THREADS; + } else { + pipeline_params.producer_arv_count = NUM_COPY_THREADS; + pipeline_params.consumer_arv_count = NUM_MMA_THREADS; + } + + PipelineParamsQ pipeline_params_q; + pipeline_params_q.role = warp_group_idx == 0 ? 
MainloopPipelineQ::ThreadCategory::Producer + : MainloopPipelineQ::ThreadCategory::Consumer; + pipeline_params_q.producer_arv_count = NUM_COPY_THREADS; + pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // just one wg qk + + + MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q); + MainloopPipeline pipeline_kv = [&] { + if constexpr (use_tma_load_kv) { + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV; + return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params, + /*cluster_shape=*/Shape<_1, _1, _1>{}); + } else { + return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params); + } + }(); + __syncthreads(); + + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue; + + if (warp_group_idx == 0) { + // producer + if constexpr(USE_REG_EALLOC) { + cutlass::arch::warpgroup_reg_dealloc<72>(); + } + const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0); + + PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state(); + if constexpr(USE_FIXED_BLOCK) { + for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) { + const int bid = mainloop_params.batch_ids[i]; + const int tile_id = mainloop_params.tile_ids_per_batch[i]; + const int seq_len_now = mainloop_params.seq_lens_this_time[bid]; + const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid]; + const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now; + const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid]; + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS, + /*id=*/static_cast(NamedBarriers::kWG0WG1WG2Sync)); + + // load Q + collective_mainloop.load_q( + mainloop_params, + pipeline_q, + smem_pipe_write_q, + shared_storage, + threadIdx.x, + bid); + + if constexpr (!use_tma_load_kv) { + // load kv + collective_mainloop.load_kv( + mainloop_params, + pipeline_kv, + smem_pipe_write_kv, + shared_storage, + bid, + seq_len_decoder_now, + tile_id + ); + } else { + if (warp_idx_in_warpgroup == 0) { + // load kv tma + collective_mainloop.load_kv_tma( + mainloop_params, + pipeline_kv, + smem_pipe_write_kv, + shared_storage, + bid, + seq_len_decoder_now, + tile_id + ); + } + } + } + } else { + const int block_id = blockIdx.x; + const int bid = mainloop_params.batch_ids[block_id]; + const int tile_id = mainloop_params.tile_ids_per_batch[block_id]; + const int seq_len_now = mainloop_params.seq_lens_this_time[bid]; + const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid]; + const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now; + const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid]; + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS, + /*id=*/static_cast(NamedBarriers::kWG0WG1WG2Sync)); + + // load Q + collective_mainloop.load_q( + mainloop_params, + pipeline_q, + smem_pipe_write_q, + shared_storage, + threadIdx.x, + bid); + + if constexpr (!use_tma_load_kv) { + // load kv + collective_mainloop.load_kv( + mainloop_params, + pipeline_kv, + smem_pipe_write_kv, + shared_storage, + bid, + seq_len_decoder_now, + tile_id + ); + } else { + if (warp_idx_in_warpgroup == 0) { + // load kv tma + collective_mainloop.load_kv_tma( + mainloop_params, + pipeline_kv, + smem_pipe_write_kv, + shared_storage, + bid, + seq_len_decoder_now, + tile_id + ); + } + } + } + } else { + // consumer + if constexpr(USE_REG_EALLOC) { + 
cutlass::arch::warpgroup_reg_alloc<216>(); + } + PipelineStateQ smem_pipe_read_q; + PipelineState smem_pipe_read_kv; + + typename Ktraits::TiledMmaPVSS tiled_mma_pv; + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{})); + + auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale); + if constexpr(USE_FIXED_BLOCK) { + for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) { + clear(tOrO); + clear(attention_updater.scores_scale); + const int bid = mainloop_params.batch_ids[i]; + const int tile_id = mainloop_params.tile_ids_per_batch[i]; + const int seq_len_now = mainloop_params.seq_lens_this_time[bid]; + const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid]; + const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now; + const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid]; + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS, + /*id=*/static_cast(NamedBarriers::kWG0WG1WG2Sync)); + + if constexpr (BLOCK_SHAPE_KV == 64) { + mma_f16( + mainloop_params, + pipeline_q, + smem_pipe_read_q, + pipeline_kv, + smem_pipe_read_kv, + tOrO, + attention_updater, + threadIdx.x - NUM_COPY_THREADS, + bid, + seq_len_decoder_now, + seq_len_now, + tile_id, + shared_storage); + } else if (BLOCK_SHAPE_KV == 32) { + mma_f16_two_stages( + mainloop_params, + pipeline_q, + smem_pipe_read_q, + pipeline_kv, + smem_pipe_read_kv, + tOrO, + attention_updater, + threadIdx.x - NUM_COPY_THREADS, + bid, + seq_len_decoder_now, + seq_len_now, + tile_id, + shared_storage); + } + + collective_epilogue.store( + epilogue_params, + tOrO, + attention_updater.get_lse(), + shared_storage, + tiled_mma_pv, + threadIdx.x - NUM_COPY_THREADS, + bid, + mainloop_params.bsz, + seq_len_now, + start_token_idx, + tile_id, + seq_len_decoder_now, + mainloop_params.chunk_size, + mainloop_params.max_draft_token_num, + mainloop_params.o_stride_bsz); + } + } else { + const int block_id = blockIdx.x; + clear(tOrO); + clear(attention_updater.scores_scale); + const int bid = mainloop_params.batch_ids[block_id]; + const int tile_id = mainloop_params.tile_ids_per_batch[block_id]; + const int seq_len_now = mainloop_params.seq_lens_this_time[bid]; + const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid]; + const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now; + const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid]; + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS, + /*id=*/static_cast(NamedBarriers::kWG0WG1WG2Sync)); + + if constexpr (BLOCK_SHAPE_KV == 64) { + mma_f16( + mainloop_params, + pipeline_q, + smem_pipe_read_q, + pipeline_kv, + smem_pipe_read_kv, + tOrO, + attention_updater, + threadIdx.x - NUM_COPY_THREADS, + bid, + seq_len_decoder_now, + seq_len_now, + tile_id, + shared_storage); + } else if (BLOCK_SHAPE_KV == 32) { + mma_f16_two_stages( + mainloop_params, + pipeline_q, + smem_pipe_read_q, + pipeline_kv, + smem_pipe_read_kv, + tOrO, + attention_updater, + threadIdx.x - NUM_COPY_THREADS, + bid, + seq_len_decoder_now, + seq_len_now, + tile_id, + shared_storage); + } + + collective_epilogue.store( + epilogue_params, + tOrO, + attention_updater.get_lse(), + shared_storage, + tiled_mma_pv, + threadIdx.x - NUM_COPY_THREADS, + bid, + mainloop_params.bsz, + seq_len_now, + start_token_idx, + tile_id, + seq_len_decoder_now, + mainloop_params.chunk_size, + mainloop_params.max_draft_token_num, + mainloop_params.o_stride_bsz); + } + } +} + + +template 
+cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params, + cudaStream_t stream) { + using DTypeQ = typename KernelTraits::DTypeQ; + using DTypeKV = typename KernelTraits::DTypeKV; + using DTypeO = typename KernelTraits::DTypeO; + using IdType = typename KernelTraits::IdType; + using NV_TYPE = typename KernelTraits::NV_TYPE; + + using CollectiveMainloop = + CollectiveMainloop; + using CollectiveEpilogue = CollectiveEpilogue; + + typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({ + make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q + make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)), + make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})), + params.Q, + params.KV, + params.m, + params.d, + params.block_tables, + params.seq_lens_this_time, + params.seq_lens_encoder, + params.seq_lens_decoder, + params.cumsum_q_seqlens, + params.batch_ids, + params.tile_ids_per_batch, + params.num_blocks_x, + params.sm_scale, + params.bsz, + params.max_block_num, + params.max_block_num_per_seq, + params.q_stride_bsz, + params.q_stride_head_num, + params.kv_stride_block_num, + params.kv_stride_block_size, + params.o_stride_bsz, + params.o_stride_head_num, + params.chunk_size, + params.chunk_num, + params.max_draft_token_num + }); + typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({ + params.O, + make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O + params.O_tmp, + make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp + }); + + // Get the ptr to kernel function. 
+ auto kernel = + MLAWithKVCacheKernel; + int smem_size = sizeof(typename KernelTraits::SharedStorage); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + int device; + cudaGetDevice(&device); + int multiprocessor_count; + cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device); + int act_blocks_per_sm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size); + + int gridx; + if constexpr(USE_FIXED_BLOCK) { + gridx = multiprocessor_count; + } else { + gridx = params.num_blocks_x_int; + } + dim3 grid_dims = {gridx, 1, 1}; + static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; + dim3 block_dims(ctaSize, 1, 1); + kernel<<>>( + mainloop_params, epilogue_params + ); + if (params.chunk_num > 1) { + constexpr int vec_size = 16 / sizeof(DTypeO); + constexpr int merge_block_size = 256; + constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size; + constexpr int blocky = (merge_block_size + blockx - 1) / blockx; + dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large + dim3 blocks_merge(blockx, blocky); + merge_multi_chunks_kernel<<>>( + reinterpret_cast(params.O_tmp), + params.m, + params.d, + params.seq_lens_this_time, + params.seq_lens_decoder, + params.seq_lens_encoder, + params.padding_offsets, + reinterpret_cast(params.O), + params.max_seq_len, + params.chunk_num, + params.q_num_head, + params.chunk_size, + params.vo_head_dim, + params.token_num, + params.bsz, + params.max_draft_token_num + ); + } + return cudaSuccess; +} + +template +cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) { + constexpr bool CAUSAL = true; + if constexpr (HEAD_DIM_QK == 576) { + DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE, + BatchMLAWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + CAUSAL, + Params, + USE_REG_EALLOC, + USE_FIXED_BLOCK>(params, stream);) + } else { + return cudaErrorNotSupported; + } + return cudaSuccess; +}; + +} // namespace mla_attn + +#endif // ATTENTION_HOPPER_PREFILL_SM90_CUH_ diff --git a/custom_ops/gpu_ops/mla_attn/named_barrier.cuh b/custom_ops/gpu_ops/mla_attn/named_barrier.cuh new file mode 100644 index 000000000..bf2f8bf21 --- /dev/null +++ b/custom_ops/gpu_ops/mla_attn/named_barrier.cuh @@ -0,0 +1,47 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri + * Dao. Licensed under the BSD 3-Clause. + * + * Modified by the FlashInfer team. 
+ */ + +#ifndef ATTENTION_HOPPER_NAMED_BARRIERS_CUH_ +#define ATTENTION_HOPPER_NAMED_BARRIERS_CUH_ + +#include + +#include "cutlass/arch/barrier.h" +#include "cutlass/cutlass.h" + +namespace mla_attn { + +enum class NamedBarriers { + kQueryEmpty = 0, + kValueEmpty = 1, + kWarpSchedulerWG1 = 2, + kWarpSchedulerWG2 = 3, + kWarpSchedulerWG3 = 4, + kPrefetchIndices = 5, + kOdone = 6, + kWG1WG2Sync = 7, + kWG0WG1WG2Sync = 8, + kWG1WG2LastSync = 9, +}; + +} // namespace mla_attn + +#endif // ATTENTION_HOPPER_NAMED_BARRIERS_CUH_ diff --git a/custom_ops/gpu_ops/mla_attn/utils.cuh b/custom_ops/gpu_ops/mla_attn/utils.cuh new file mode 100644 index 000000000..e3413f752 --- /dev/null +++ b/custom_ops/gpu_ops/mla_attn/utils.cuh @@ -0,0 +1,351 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ATTENTION_HOPPER_UTILS_CUH_ +#define ATTENTION_HOPPER_UTILS_CUH_ + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include +#include +#include +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include +#include +#include +#include + +#include +#include +#include +#include "cutlass/fast_math.h" + +namespace mla_attn { + +using namespace cute; + +template +CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) { + Tensor tensor_flatten = cute::flatten(tensor); + return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten); +} + +CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride, + int64_t h_stride) { + return make_layout(make_shape(nnz, head_dim, num_heads), + make_stride(n_stride, cute::_1{}, h_stride)); +} + +CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) { + return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads))); +} + +template +CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape, + int head_idx, int offset, int seq_len) { + 
auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)), + make_coord(offset, _0{})); + auto g_sequence = + make_tensor(g_offset.data(), + make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride())); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); + return g_tensor; +} + +template +CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape, + int head_idx, int offset, int seq_len) { + auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset)); + + auto g_sequence = make_tensor(g_offset.data(), make_layout(cute::make_shape(seq_len), + cute::make_shape(shape<0>(m_tensor)))); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_)); + return g_tensor; +} + +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, +// MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), + make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); +}; + +// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, +// MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + auto l = logical_divide(get<0>(acc_layout), Shape{}); // (2, 2, (2, N / 16))) + return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), + make_layout(get<2, 1>(l), get<2>(acc_layout))); +}; + +template +__forceinline__ __device__ auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +template +__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB, + TensorC& tCrC) { + constexpr bool Is_RS = + !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { + warpgroup_fence_operand(const_cast(tCrA)); + } + warpgroup_fence_operand(tCrC); + warpgroup_arrive(); + if constexpr (init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { + warpgroup_wait(); + } + 
warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { + warpgroup_fence_operand(const_cast(tCrA)); + } +} + +#define HOSTDEVICE __host__ __device__ + +template +struct alignas(sizeof(T) * Size) AlignedVector { + T val[Size]; + + HOSTDEVICE inline const T& operator[](int i) const { return val[i]; } + HOSTDEVICE inline T& operator[](int i) { return val[i]; } +}; + +template +HOSTDEVICE inline void Load(const T* addr, AlignedVector* vec) { + const AlignedVector* addr_vec = + reinterpret_cast*>(addr); + *vec = *addr_vec; +} + +template +HOSTDEVICE inline void Store(const AlignedVector& vec, T* addr) { + AlignedVector* addr_vec = + reinterpret_cast*>(addr); + *addr_vec = vec; +} + +template +struct prefill_softmax_state_t { + AlignedVector o; + float m; + float d; + + __device__ __forceinline__ void init() { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&o) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0); + } + } + d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.38953e38f; + } + } + + __device__ __forceinline__ void merge(const AlignedVector& other_o, + const float other_m, + const float other_d) { + float m_prev = m, d_prev = d; + m = max(m_prev, other_m); + const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m); + const T scale1_T = static_cast(scale1), scale2_T = static_cast(scale2); + d = d_prev * scale1 + other_d * scale2; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] = o[i] * scale1_T + other_o[i] * scale2_T; + } + } + + __device__ __forceinline__ void normalize() { + const T d_t = static_cast(d); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } +}; + +template +__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim] + const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads] + const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads] + const int * __restrict__ seq_lens_this_time, + const int * __restrict__ seq_lens_decoder, + const int * __restrict__ seq_lens_encoder, + const int * __restrict__ padding_offsets, + T * __restrict__ out, // [token_num, num_heads, head_dim] + const int max_seq_len, + const int num_chunks, + const int num_heads, + const int chunk_size, + const int head_dim, + const int token_num, + const int bsz, + const int max_draft_token_num) { + const int vid = threadIdx.x, ty = threadIdx.y; + const int hid = blockIdx.y; + __shared__ T smem[bdy * HEAD_DIM]; + __shared__ float md_smem[bdy * 2]; + for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { + const uint32_t ori_token_id = qid + padding_offsets[qid]; + const uint32_t bid = ori_token_id / max_seq_len; + const int seq_len_q = seq_lens_this_time[bid]; + if (seq_len_q == 0) continue; + const uint32_t local_seq_id = ori_token_id % max_seq_len; + int seq_len_kv = seq_lens_decoder[bid]; + if (seq_len_kv == 0) continue; + seq_len_kv += seq_len_q; + const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size); + if (num_chunks_this_seq <= 1) { + // not need merge + continue; + } + + using LoadT = AlignedVector; + LoadT load_vec; + LoadT res_vec; + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { 
+ *((half2*)(&res_vec) + i) = make_half2(0, 0); + } + } else { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); + } + } + float m; + float d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.0e+30f; + } + + for (int i = ty; i < num_chunks_this_seq; i += bdy) { + uint32_t offset; + offset = ((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid; + float m_prev = m; + float d_prev = d; + const float m_now = multi_m[offset]; + const float d_now = multi_d[offset]; + m = max(m_prev, m_now); + offset = (((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid) * head_dim + vid * vec_size; + Load(&multi_out[offset], &load_vec); + const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m); + const T scale1_T = static_cast(scale1), scale2_T = static_cast(scale2); + d = d * scale1 + d_now * scale2; +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T; + } + } + // store ty res + Store(res_vec, &smem[ty * head_dim + vid * vec_size]); + md_smem[2 * ty] = m; + md_smem[2 * ty + 1] = d; + __syncthreads(); + if (ty == 0) { + // merge bdy + prefill_softmax_state_t st; + st.init(); +#pragma unroll + for (int i = 0; i < bdy; i++) { + Load(&smem[i * head_dim + vid * vec_size], &load_vec); + const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; + st.merge(load_vec, m_tmp, d_tmp); + } + st.normalize(); + Store(st.o, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); + } + __syncthreads(); + } +} + +} // namespace mla_attn + +#endif // ATTENTION_HOPPER_UTILS_CUH_ diff --git a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_template.h b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_template.h index d7f9f17dc..6de2cd83d 100644 --- a/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_template.h +++ b/custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_template.h @@ -1255,8 +1255,6 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float) { if (is_new_zp) { if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; - FragB frag_zp_0; - FragB frag_zp_1; int zp_quant_0, zp_quant_1; if constexpr (w_type.size_bits() == 4) { diff --git a/custom_ops/gpu_ops/multi_head_latent_attention.cu b/custom_ops/gpu_ops/multi_head_latent_attention.cu new file mode 100644 index 000000000..2d22ce225 --- /dev/null +++ b/custom_ops/gpu_ops/multi_head_latent_attention.cu @@ -0,0 +1,469 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
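// For intuition: when a long sequence is split across KV chunks, each chunk writes a
// partial output together with its row max m and row sum d, and merge_multi_chunks_kernel
// (utils.cuh above) combines them with the same rescaling rule used inside the mainloop.
// The host-side C++ sketch below is illustrative only; merge_chunks and PartialChunk are
// hypothetical names, and it assumes the per-chunk partial outputs are stored un-normalized.
#include <algorithm>
#include <cmath>
#include <vector>

struct PartialChunk {
  float m;               // row max of this chunk's scores
  float d;               // row sum of exp(s - m) over this chunk
  std::vector<float> o;  // un-normalized partial output, length head_dim
};

inline std::vector<float> merge_chunks(const std::vector<PartialChunk>& chunks) {
  float m = -INFINITY, d = 0.f;
  std::vector<float> o(chunks.front().o.size(), 0.f);
  for (const auto& c : chunks) {
    const float m_new = std::max(m, c.m);
    const float s_old = std::exp(m - m_new);    // rescale what has been accumulated so far
    const float s_new = std::exp(c.m - m_new);  // rescale the incoming chunk
    d = d * s_old + c.d * s_new;
    for (size_t k = 0; k < o.size(); ++k) o[k] = o[k] * s_old + c.o[k] * s_new;
    m = m_new;
  }
  for (float& x : o) x /= d;  // single final normalization
  return o;
}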
+ +#include "append_attn/multi_head_latent_attention_kernel.h" +#include "mla_attn/batch_mla_with_paged_kv_cache.h" + +template +std::vector MultiHeadLatentAttentionKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& query, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks, + const paddle::Tensor& decoder_num_blocks_cpu, + const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& max_dec_len_this_time, + const paddle::Tensor& max_len_kv, + const paddle::optional& attn_mask, + const paddle::optional& query_bias, + const paddle::optional& query_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const std::string& cache_quant_type_str, + const int max_input_length, + const float softmax_scale, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder) { + typedef PDTraits traits_; + typedef typename traits_::data_t data_t; + + int decoder_num_blocks_data = decoder_num_blocks_cpu.data()[0]; + int max_dec_len_this_time_data = max_dec_len_this_time.data()[0]; + int max_len_kv_data = max_len_kv.data()[0]; + + const bool mla_use_tensorcore = get_mla_use_tensorcore(); + auto sm_version = GetSMVersion(); + if ((speculate_decoder || mla_use_tensorcore) && sm_version < 90) { + PD_THROW("Please use speculate_decoder=0 and FLAGS_mla_use_tensorcore=0 when sm < 90."); + } + + auto main_stream = query.stream(); + + paddle::Tensor fmha_out = paddle::full( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v}, + 0, + D, + query.place()); + + if (max_dec_len_this_time_data > 0) { + if (mla_use_tensorcore) { + BatchMLAWithPagedKVCacheKernel(meta_data, + query, + key_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + padding_offsets, + cum_offsets, + block_tables, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + cache_quant_type_str, + decoder_num_blocks_data, + max_input_length, + max_len_kv_data, + softmax_scale, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + speculate_max_draft_token_num, + causal, + main_stream, + &fmha_out); + } else { + DecodeMLAAttentionKernel( + meta_data, + query, // [token_num, num_heads, head_dim] + key_cache, + value_cache, + attn_mask, + out_linear_shifts, + out_linear_smooths, + 
seq_lens_this_time, // q_seq_len is 1 + seq_lens_decoder, + padding_offsets, + cum_offsets, + block_tables, + max_input_length, + max_len_kv_data, + softmax_scale, + out_linear_in_scale, + causal, + main_stream, + &fmha_out); + } + } + return {fmha_out}; +} + +std::vector MultiHeadLatentAttention( + const paddle::Tensor& query, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks, + const paddle::Tensor& decoder_num_blocks_cpu, + const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& max_dec_len_this_time, + const paddle::Tensor& max_len_kv, + const paddle::optional& attn_mask, + const paddle::optional& query_bias, + const paddle::optional& query_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const int nope_size, + const int max_input_length, + const float softmax_scale, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder) { + AppendAttnMetaData meta_data; + + const auto& query_dims = query.dims(); + const auto& key_cache_dims = key_cache.dims(); + const int q_hidden_size = query_dims[query_dims.size() - 1]; + meta_data.token_nums = query_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + meta_data.head_dims_v = nope_size; + meta_data.q_num_heads = q_hidden_size / meta_data.head_dims; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = cum_offsets.dims()[0]; + + switch (query.dtype()) { + case paddle::DataType::BFLOAT16: { + return MultiHeadLatentAttentionKernel( + meta_data, + query, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + cu_seqlens_q, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + decoder_num_blocks_cpu, + max_enc_len_this_time, + max_dec_len_this_time, + max_len_kv, + attn_mask, + query_bias, + query_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + cache_quant_type_str, + max_input_length, + softmax_scale, + quant_max_bound, + 
quant_min_bound, + out_linear_in_scale, + speculate_max_draft_token_num, + causal, + speculate_decoder); + } + case paddle::DataType::FLOAT16: { + return MultiHeadLatentAttentionKernel( + meta_data, + query, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + cu_seqlens_q, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + decoder_num_blocks_cpu, + max_enc_len_this_time, + max_dec_len_this_time, + max_len_kv, + attn_mask, + query_bias, + query_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + cache_quant_type_str, + max_input_length, + softmax_scale, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + speculate_max_draft_token_num, + causal, + speculate_decoder); + } + default: { + PD_THROW( + "NOT supported data type. " + "Only float16 and bfloat16 are supported. "); + break; + } + } +} + +std::vector> MultiHeadLatentAttentionInferShape( + const std::vector& query_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& padding_offsets_shape, + const std::vector& cum_offsets_shape, + const std::vector& block_tables_shape, + const std::vector& encoder_batch_ids_shape, + const std::vector& encoder_tile_ids_per_batch_shape, + const std::vector& encoder_num_blocks_shape, + const std::vector& kv_batch_ids_shape, + const std::vector& kv_tile_ids_per_batch_shape, + const std::vector& kv_num_blocks_shape, + const std::vector& decoder_batch_ids_shape, + const std::vector& decoder_tile_ids_per_batch_shape, + const std::vector& decoder_num_blocks_shape, + const std::vector& decoder_num_blocks_cpu_shape, + const std::vector& max_enc_len_this_time_shape, + const std::vector& max_dec_len_this_time_shape, + const std::vector& max_len_kv_shape, + const paddle::optional>& attn_mask_shape, + const paddle::optional>& query_bias_shape, + const paddle::optional>& query_out_scales_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& out_linear_shifts_shape, + const paddle::optional>& out_linear_smooths_shape, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const int nope_size, + const int max_input_length, + const float softmax_scale, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder) { + const int token_num = query_shape[0]; + const int kv_num_heads = key_cache_shape[1]; + const int head_dim_qk = key_cache_shape[3]; + const int head_dim_v = nope_size; + const int q_hidden_size = query_shape[query_shape.size() - 1]; + const int num_heads = q_hidden_size / head_dim_qk; + return {{token_num, num_heads * head_dim_v}}; +} + +std::vector 
MultiHeadLatentAttentionInferDtype( + const paddle::DataType& query_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& padding_offsets_dtype, + const paddle::DataType& cum_offsets_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& encoder_batch_ids_dtype, + const paddle::DataType& encoder_tile_ids_per_batch_dtype, + const paddle::DataType& encoder_num_blocks_dtype, + const paddle::DataType& kv_batch_ids_dtype, + const paddle::DataType& kv_tile_ids_per_batch_dtype, + const paddle::DataType& kv_num_blocks_dtype, + const paddle::DataType& decoder_batch_ids_dtype, + const paddle::DataType& decoder_tile_ids_per_batch_dtype, + const paddle::DataType& decoder_num_blocks_dtype, + const paddle::DataType& decoder_num_blocks_cpu_dtype, + const paddle::DataType& max_enc_len_this_time_dtype, + const paddle::DataType& max_dec_len_this_time_dtype, + const paddle::DataType& max_len_kv_dtype, + const paddle::optional& attn_mask_dtype, + const paddle::optional& query_bias_dtype, + const paddle::optional& query_out_scales_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& out_linear_shifts_dtype, + const paddle::optional& out_linear_smooths_dtype, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const int nope_size, + const int max_input_length, + const float softmax_scale, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder) { + if (compute_dtype == "bf16") { + return {paddle::DataType::BFLOAT16}; + } else if (compute_dtype == "fp16") { + return {paddle::DataType::FLOAT16}; + } else { + PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16']."); + } +} + +PD_BUILD_OP(multi_head_latent_attention) + .Inputs({"query", + "key_cache", + "value_cache", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "cu_seqlens_q", + "padding_offsets", + "cum_offsets", + "block_tables", + "encoder_batch_ids", + "encoder_tile_ids_per_batch", + "encoder_num_blocks", + "kv_batch_ids", + "kv_tile_ids_per_batch", + "kv_num_blocks", + "decoder_batch_ids", + "decoder_tile_ids_per_batch", + "decoder_num_blocks", + "decoder_num_blocks_cpu", + "max_enc_len_this_time", + "max_dec_len_this_time", + "max_len_kv", + paddle::Optional("attn_mask"), + paddle::Optional("query_bias"), + paddle::Optional("query_out_scales"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("out_linear_shifts"), + paddle::Optional("out_linear_smooths")}) + .Outputs({"fmha_out"}) + .Attrs({"compute_type: std::string", + "cache_quant_type: std::string", + "nope_size: int", + "max_input_length: int", + "softmax_scale: float", + "quant_max_bound: float", + "quant_min_bound: float", 
+ "out_linear_in_scale: float", + "speculate_max_draft_token_num: int", + "causal: bool", + "speculate_decoder: bool"}) + .SetKernelFn(PD_KERNEL(MultiHeadLatentAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(MultiHeadLatentAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MultiHeadLatentAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/noaux_tc.cu b/custom_ops/gpu_ops/noaux_tc.cu new file mode 100644 index 000000000..a14f7443b --- /dev/null +++ b/custom_ops/gpu_ops/noaux_tc.cu @@ -0,0 +1,73 @@ + +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "helper.h" +#include "noauxtc_kernel.h" + +std::vector NoauxTc(paddle::Tensor& scores, + paddle::Tensor& scores_with_bias, + int n_group, + int topk_group, + int topk, + float routed_scaling_factor) { + auto input_shape = scores_with_bias.shape(); + int64_t num_tokens = input_shape[0]; + int64_t num_experts = input_shape[1]; + auto input_type = scores_with_bias.dtype(); + auto place = scores_with_bias.place(); + auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place); + auto stream = scores_with_bias.stream(); + + invokeNoAuxTc(reinterpret_cast(scores.data()), + reinterpret_cast(group_scores.data()), + reinterpret_cast(scores_with_bias.data()), + num_tokens, + num_experts, + n_group, + topk_group, + topk, + routed_scaling_factor, + stream); + + return {scores}; +} + +std::vector NoauxTcInferDtype( + const paddle::DataType& scores_dtype, + const paddle::DataType& scores_with_bias_dtype) { + return {scores_dtype}; +} + +std::vector> NoauxTcInferShape( + const std::vector& scores_shape, + const std::vector& gating_output_shape) { + return {scores_shape}; +} + +PD_BUILD_OP(noaux_tc) + .Inputs({"scores", "scores_with_bias"}) + .Outputs({"output_tensor"}) + .Attrs({"n_group: int", + "topk_group: int", + "topk:int", + "routed_scaling_factor: float"}) + .SetKernelFn(PD_KERNEL(NoauxTc)) + .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcInferDtype)); diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h new file mode 100644 index 000000000..bce305edc --- /dev/null +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -0,0 +1,551 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// This code is partially inspired by and references the implementation found +// in NVIDIA TRTLLM. +#pragma once +#include +#include + +namespace cg = cooperative_groups; + +constexpr unsigned FULL_WARP_MASK = 0xffffffff; +constexpr int32_t WARP_SIZE = 32; +constexpr int32_t BLOCK_SIZE = 512; +constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; + +namespace warp_topk { + +template +__host__ __device__ constexpr T round_up_to_multiple_of(T len) { + if (len == 0) { + return 0; + } + return ((len - 1) / size + 1) * size; +} + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +template +__device__ bool is_better_than(T val, T baseline) { + return (val > baseline && greater) || (val < baseline && !greater); +} + +template +int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) { + int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k; + int64_t n = std::max(num_of_warp / 2 * k, num_of_warp * WARP_SIZE); + return max(cache_topk, + round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT)); +} + +template +struct BitonicMerge { + // input should be a bitonic sequence, and sort it to be a monotonic sequence + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + constexpr int stride = arr_len / 2; + for (int i = 0; i < stride; ++i) { + int const other_i = i + stride; + T& val = val_arr[i]; + T& other_val = val_arr[other_i]; + if ((val > other_val && ascending) || (val < other_val && !ascending)) { + T tmp = val; + val = other_val; + other_val = tmp; + + idxT tmp2 = idx_arr[i]; + idx_arr[i] = idx_arr[other_i]; + idx_arr[other_i] = tmp2; + } + } + + BitonicMerge::merge(val_arr, idx_arr); + BitonicMerge::merge(val_arr + arr_len / 2, + idx_arr + arr_len / 2); + } +}; + +template +struct BitonicSort { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + BitonicSort::sort(val_arr, idx_arr); + BitonicSort::sort(val_arr + arr_len / 2, + idx_arr + arr_len / 2); + BitonicMerge::merge(val_arr, idx_arr); + } +}; + +template +struct BitonicSort<32, ascending, T, idxT> { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + + // ascending doesn't matter before merging since all we need is a bitonic + // sequence + for (int stage = 0; stage < 4; ++stage) { + for (int stride = (1 << stage); stride > 0; stride /= 2) { + bool reverse = (lane >> stage) & 2; + bool is_second = lane & stride; + + T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride); + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride); + if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) { + *val_arr = other; + *idx_arr = other_idx; + } + } + } + + BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr); + } +}; + +template +struct BitonicMerge<32, ascending, T, idxT> { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) { + bool is_second = lane & stride; + T& val = *val_arr; + T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride); + idxT& idx = *idx_arr; + idxT other_idx = 
__shfl_xor_sync(FULL_WARP_MASK, idx, stride); + if (val != other && ((val > other) == (ascending != is_second))) { + val = other; + idx = other_idx; + } + } + } +}; + +template +class WarpSort { +public: + __device__ WarpSort(idxT k, T dummy) + : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { + static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); + + for (int i = 0; i < max_arr_len_; ++i) { + val_arr_[i] = dummy_; + } + } + + // load and merge k sorted values + __device__ void load_sorted(T const* __restrict__ in, + idxT const* __restrict__ in_idx, + idxT start) { + idxT idx = start + WARP_SIZE - 1 - lane_; + for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) { + if (idx < start + k_) { + T t = in[idx]; + if (is_better_than(t, val_arr_[i])) { + val_arr_[i] = t; + idx_arr_[i] = in_idx[idx]; + } + } + } + + BitonicMerge::merge(val_arr_, idx_arr_); + } + + __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out[out_i] = val_arr_[i]; + out_idx[out_i] = idx_arr_[i]; + } + } + } + + __device__ void dumpIdx(idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out_idx[out_i] = idx_arr_[i]; + } + } + } + +protected: + static constexpr int max_arr_len_ = capacity / WARP_SIZE; + + T val_arr_[max_arr_len_]; + idxT idx_arr_[max_arr_len_]; + + int const lane_; + idxT const k_; + T const dummy_; + +}; // end class WarpSort + +template +class WarpSelect : public WarpSort { +public: + __device__ WarpSelect(idxT k, T dummy) + : WarpSort(k, dummy), + k_th_(dummy), + k_th_lane_((k - 1) % WARP_SIZE) { + extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[]; + + int const num_of_warp = blockDim.x / WARP_SIZE; + int const warp_id = threadIdx.x / WARP_SIZE; + val_smem_ = reinterpret_cast(smem_buf); + val_smem_ += warp_id * WARP_SIZE; + idx_smem_ = reinterpret_cast( + smem_buf + + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE)); + idx_smem_ += warp_id * WARP_SIZE; + } + + __device__ void add(T const* in, idxT start, idxT end) { + idxT const end_for_fullwarp = + round_up_to_multiple_of(end - start) + start; + for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) { + T val = (i < end) ? in[i] : dummy_; + add(val, i); + } + } + + __device__ void add(T val, idxT idx) { + bool do_add = is_better_than(val, k_th_); + uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add); + if (mask == 0) { + return; + } + + int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1)); + if (do_add && pos < WARP_SIZE) { + val_smem_[pos] = val; + idx_smem_[pos] = idx; + do_add = false; + } + smem_buf_len_ += __popc(mask); + if (smem_buf_len_ >= WARP_SIZE) { + __syncwarp(); + merge_buf_(val_smem_[lane_], idx_smem_[lane_]); + smem_buf_len_ -= WARP_SIZE; + } + if (do_add) { + pos -= WARP_SIZE; + val_smem_[pos] = val; + idx_smem_[pos] = idx; + } + __syncwarp(); + } + + __device__ void done() { + if (smem_buf_len_) { + T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_; + idxT idx = (lane_ < smem_buf_len_) ? 
idx_smem_[lane_] : 0; + merge_buf_(val, idx); + } + + // after done(), smem is used for merging results among warps + __syncthreads(); + } + +private: + __device__ void set_k_th_() { + k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); + } + + __device__ void merge_buf_(T val, idxT idx) { + BitonicSort::sort(&val, &idx); + + T& old = val_arr_[max_arr_len_ - 1]; + if (is_better_than(val, old)) { + old = val; + idx_arr_[max_arr_len_ - 1] = idx; + } + + BitonicMerge::merge(val_arr_, idx_arr_); + + set_k_th_(); + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + T* val_smem_; + idxT* idx_smem_; + int smem_buf_len_ = 0; + + T k_th_; + int const k_th_lane_; +}; // end class WarpSelect +} // namespace warp_topk + +template +__device__ void topk_with_k2(T* output, + T const* input, + cg::thread_block_tile<32> const& tile, + int32_t const lane_id, + int const num_experts_per_group) { + // Get the top2 per thread + T largest = cuda::std::numeric_limits::min(); + T second_largest = cuda::std::numeric_limits::min(); + + if (num_experts_per_group > WARP_SIZE) { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + T value = input[i]; + if (value > largest) { + second_largest = largest; + largest = value; + } else if (value > second_largest) { + second_largest = value; + } + } + } else { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + largest = input[i]; + } + } + + __syncwarp(); // Ensure all threads have valid data before reduction + // Get the top2 warpwise + T max1 = cg::reduce(tile, largest, cg::greater()); + + T max2 = max1; + bool equal_to_max1 = (max1 == largest); + int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1)); + + if (count_max1 == 1) { + largest = (largest == max1) ? 
second_largest : largest; + max2 = cg::reduce(tile, largest, cg::greater()); + } + + if (lane_id == 0) { + *output = max1 + max2; + } +} + +template +__global__ void topk_with_k2_kernel(T* output, + T* input, + int64_t const num_tokens, + int64_t const num_cases, + int64_t const n_group, + int64_t const num_experts_per_group) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + + int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; + if (case_id < num_cases) { + input += case_id * num_experts_per_group; + output += case_id; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + topk_with_k2(output, input, tile, lane_id, num_experts_per_group); + } +} + +template +__global__ void group_idx_and_topk_idx_kernel( + T* scores, + T const* group_scores, + T* scores_with_bias, + int64_t const num_tokens, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + int64_t const num_experts, + int64_t const num_experts_per_group, + double routed_scaling_factor) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + int32_t case_id = + blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token + scores_with_bias += case_id * num_experts; + scores += case_id * num_experts; + group_scores += case_id * n_group; + int32_t align_num_experts_per_group = + warp_topk::round_up_to_multiple_of(num_experts_per_group); + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to + // store the target topk idx + int32_t* s_topk_idx = reinterpret_cast(smem_buf) + warp_id * topk; + T* s_topk_value = + reinterpret_cast(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + + warp_id * topk; + + T value = cuda::std::numeric_limits::min(); + T topk_group_value = cuda::std::numeric_limits::min(); + int32_t num_equalto_topkth_group; + + if ((n_group > topk_group) && (case_id < num_tokens)) { + // calculate group_idx + int32_t target_num_min = WARP_SIZE - n_group + topk_group; + if (lane_id < n_group) { + value = group_scores[lane_id]; + } + + int count_equal_to_top_value = WARP_SIZE - n_group; + int pre_count_equal_to_top_value = 0; + // Use loop to find the largset top_group + while (count_equal_to_top_value < target_num_min) { + __syncwarp(); // Ensure all threads have valid data before reduction + topk_group_value = cg::reduce(tile, value, cg::greater()); + if (value == topk_group_value) { + value = cuda::std::numeric_limits::min(); + } + pre_count_equal_to_top_value = count_equal_to_top_value; + count_equal_to_top_value = __popc(__ballot_sync( + FULL_WARP_MASK, (value == cuda::std::numeric_limits::min()))); + } + num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; + } + __syncthreads(); + + warp_topk::WarpSelect + queue((int32_t)topk, cuda::std::numeric_limits::min()); + + int count_equalto_topkth_group = 0; + if (case_id < num_tokens) { + for (int i_group = 0; i_group < n_group; i_group++) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = i < num_experts_per_group + ? 
scores_with_bias[offset + i] + : cuda::std::numeric_limits::min(); + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } + } + } + queue.done(); + __syncwarp(); + // Get the topk_idx + queue.dumpIdx(s_topk_idx); + __syncwarp(); + } + + // Load the valid score value + // Calculate the summation + float topk_sum = 1e-20; + if (case_id < num_tokens) { + for (int i = lane_id; + i < warp_topk::round_up_to_multiple_of(topk); + i += WARP_SIZE) { + T value = i < topk ? scores[s_topk_idx[i]] + : 0.0f; // Load the valid value of expert + if (i < topk) { + s_topk_value[i] = value; + } + topk_sum += reduce(tile, value, cg::plus()); + } + } + + __syncthreads(); + if (case_id < num_tokens) { + for (int i = lane_id; i < num_experts; i += WARP_SIZE) { + scores[i] = 0; + } + } + __threadfence(); + __syncthreads(); + + if (case_id < num_tokens) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value = s_topk_value[i] / topk_sum * routed_scaling_factor; + scores[s_topk_idx[i]] = value; + } + } +} + +template +void invokeNoAuxTc(T* scores, + T* group_scores, + T* scores_with_bias, + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + double const routed_scaling_factor, + cudaStream_t const stream) { + int64_t num_cases = num_tokens * n_group; + int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; + topk_with_k2_kernel<<>>( + group_scores, + scores_with_bias, + num_tokens, + num_cases, + n_group, + num_experts / n_group); + + int64_t topk_with_k_group_num_blocks = + (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; + size_t dynamic_smem_in_bytes = + warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, + topk); + + group_idx_and_topk_idx_kernel<<>>(scores, + group_scores, + scores_with_bias, + num_tokens, + n_group, + topk_group, + topk, + num_experts, + num_experts / n_group, + routed_scaling_factor); +} + +#define INSTANTIATE_NOAUX_TC(T) \ + template void invokeNoAuxTc(T * scores, \ + T * group_scores, \ + T * scores_with_bias, \ + int64_t const num_tokens, \ + int64_t const num_experts, \ + int64_t const n_group, \ + int64_t const topk_group, \ + int64_t const topk, \ + double const routed_scaling_factor, \ + cudaStream_t const stream); + +INSTANTIATE_NOAUX_TC(float); diff --git a/custom_ops/gpu_ops/per_token_quant_fp8.cu b/custom_ops/gpu_ops/per_token_quant_fp8.cu index f195403a5..9a16d4d36 100644 --- a/custom_ops/gpu_ops/per_token_quant_fp8.cu +++ b/custom_ops/gpu_ops/per_token_quant_fp8.cu @@ -50,11 +50,13 @@ __global__ void quant_per_token_per_block(const T *input, max_value_thread = max(abs(load_vec_float[vid]), max_value_thread); } // get max value per warp - max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread); - max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread); - max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread); - max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread); - max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread); + max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread); + max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread); + max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread); + 
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread); + max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread); + // broadcast max_value + max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0); max_value_thread = max(max_value_thread, epsilon); float scale_to_store = max_value_thread / MAX_VALUE; // quant diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 64136a2a9..dd26d6e90 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -267,6 +267,9 @@ elif paddle.is_compiled_with_cuda(): "gpu_ops/text_image_index_out.cu", "gpu_ops/text_image_gather_scatter.cu", "gpu_ops/sample_kernels/rejection_top_p_sampling.cu", + "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", + "gpu_ops/fused_rotary_position_encoding.cu", + "gpu_ops/noaux_tc.cu", ] # pd_disaggregation @@ -376,6 +379,8 @@ elif paddle.is_compiled_with_cuda(): # append_attention sources += ["gpu_ops/append_attention.cu"] sources += find_end_files("gpu_ops/append_attn", ".cu") + # mla + sources += ["gpu_ops/multi_head_latent_attention.cu"] # gemm_dequant sources += ["gpu_ops/int8_gemm_with_cutlass/gemm_dequant.cu"] # speculate_decoding @@ -441,6 +446,10 @@ elif paddle.is_compiled_with_cuda(): sources += find_end_files(fp8_auto_gen_directory, ".cu") + if cc >= 90 and nvcc_version >= 12.0: + # Hopper optmized mla + sources += find_end_files("gpu_ops/mla_attn", ".cu") + setup( name="fastdeploy_ops", ext_modules=CUDAExtension( diff --git a/fastdeploy/__init__.py b/fastdeploy/__init__.py index f44e03840..b87fb002c 100644 --- a/fastdeploy/__init__.py +++ b/fastdeploy/__init__.py @@ -15,6 +15,8 @@ """ import os +import subprocess +import sys # suppress warning log from paddlepaddle os.environ["GLOG_minloglevel"] = "2" @@ -30,3 +32,48 @@ try: use_triton_in_paddle.make_triton_compatible_with_paddle() except ImportError: pass +# TODO(tangbinhan): remove this code + + +def _patch_fastsafetensors(): + try: + file_path = subprocess.check_output([ + sys.executable, "-c", "import fastsafetensors, os; \ + print(os.path.join(os.path.dirname(fastsafetensors.__file__), \ + 'frameworks', '_paddle.py'))" + ]).decode().strip() + + with open(file_path, 'r') as f: + content = f.read() + if "DType.U16: DType.BF16," in content and "DType.U8: paddle.uint8," in content: + return + + modified = False + if "DType.U16: DType.BF16," not in content: + lines = content.splitlines() + new_lines = [] + inside_block = False + for line in lines: + new_lines.append(line) + if 'need_workaround_dtypes: Dict[DType, DType] = {' in line: + inside_block = True + elif inside_block and '}' in line: + new_lines.insert(-1, ' DType.U16: DType.BF16,') + inside_block = False + modified = True + content = "\n".join(new_lines) + + if "DType.I8: paddle.uint8," in content: + content = content.replace("DType.I8: paddle.uint8,", + "DType.U8: paddle.uint8,") + modified = True + + if modified: + with open(file_path, 'w') as f: + f.write(content + "\n") + + except Exception as e: + print(f"Failed to patch fastsafetensors: {e}") + + +_patch_fastsafetensors() diff --git a/fastdeploy/config.py b/fastdeploy/config.py index d2424d42a..715df04eb 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -18,7 +18,7 @@ from __future__ import annotations from dataclasses import dataclass, field from enum import Enum -from typing import Optional +from typing import Optional, Literal from paddleformers.transformers.configuration_utils import PretrainedConfig @@ -51,7 +51,6 @@ class 
ModelConfig(PretrainedConfig): top_p = 0.0 temperature = 1.0 rope_theta = 10000.0 - rope_scaling = None penalty_score = 1.0 frequency_score = 0.0 presence_score = 0.0 @@ -142,6 +141,7 @@ class MoEConfig: moe_num_shared_experts = (0, ) moe_layer_start_index = 0 moe_layer_end_index = None + moe_use_aux_free: bool = False num_max_dispatch_tokens_per_rank = 256 im_patch_id = ( 100295 # multimodality, TODO(liuyuanle): read from config.json @@ -163,7 +163,6 @@ class ParallelConfig: # The embedding weight distributed on your gpu cards is divided by row or column. # Defaults to False means divide by row. When vocab_size can not be divided by world_size # but hidden_size can, we can consider split embedding weight by column. - column_cut = False # (bool, optional) """ From old wersion worker args TODO(gongshaotian): Reclassify @@ -194,18 +193,13 @@ class ParallelConfig: engine_pid: Optional[int] = None # Do profile or not do_profile: bool = False - # Dynamic load weight or not - dynamic_load_weight: bool = False # pad_token_id: int = -1 # eos_tokens_lens: int = 2 # Enable chunked prefill enable_chunked_prefill: str = "store_true" - """ - - APPEND_ATTN: - """ - attention_backend: str = "APPEND_ATTN" + # max_num_batched_tokens: int = 2048 # enable prefix cache enable_prefix_caching = None @@ -354,9 +348,27 @@ class GraphOptimizationConfig: @dataclass class LoadConfig: """ - Configuration for loading parameter + Configuration for dynamic weight loading strategies + + Attributes: + dynamic_load_weight: Whether to enable dynamic weight loading + load_strategy: Specifies the weight loading method when enabled: + - 'ipc': Real-time IPC streaming with automatic resharding + - 'ipc_no_reshard': Real-time IPC streaming without weight process + - 'ipc_snapshot': Load from disk snapshot of IPC weights + - 'meta': provide RL traing worker, no_weights_load + - None: No dynamic loading """ - pass + use_fastsafetensor: bool = False + dynamic_load_weight: bool = False + load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None + + def __post_init__(self): + if self.load_strategy is not None and not self.dynamic_load_weight: + raise ValueError("Load strategy requires dynamic_load_weight=True") + + if self.dynamic_load_weight and self.load_strategy is None: + raise ValueError("Must specify load_strategy when dynamic_load_weight is True") @dataclass @@ -392,7 +404,7 @@ class FDConfig: init=True) # type: ignore device_config: DeviceConfig = field(default=None, init=True) # type: ignore - load_config: LoadConfig = field(default=None, init=True) # type: ignore + load_config: LoadConfig = field(default=None, init=True) quant_config: Optional[QuantConfigBase] = None graph_opt_config: Optional[GraphOptimizationConfig] = None moe_config: MoEConfig = field(default=None, init=True) # type: ignore diff --git a/fastdeploy/demo/offline_disaggregated_demo.py b/fastdeploy/demo/offline_disaggregated_demo.py index 67ee214a2..82831649c 100644 --- a/fastdeploy/demo/offline_disaggregated_demo.py +++ b/fastdeploy/demo/offline_disaggregated_demo.py @@ -16,48 +16,54 @@ import time import os -import subprocess -import signal +import multiprocessing from fastdeploy.entrypoints.llm import LLM from fastdeploy.engine.sampling_params import SamplingParams -model_name_or_path = "./models/eb45t02/" + +model_name_or_path = "baidu/ERNIE-4.5-21B-A3B-Paddle" -prefill_cmd = (f"FD_LOG_DIR=log_prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python fastdeploy.entrypoints.openai.api_server.py" - + f" --model {model_name_or_path} --port 
9811" - + f" --splitwise-role prefill --tensor-parallel-size 4" - + f" --engine-worker-queue-port 6676 --cache-queue-port 55663") +def start_decode(model_name_or_path): + os.environ["CUDA_VISIBLE_DEVICES"] = "1" + os.environ["FD_LOG_DIR"] = "log_decode" + llm_decode = LLM( + model=model_name_or_path, + tensor_parallel_size=1, + splitwise_role="decode", + engine_worker_queue_port=6678, + innode_prefill_ports=[6676], + cache_queue_port=55668 + ) + return llm_decode -prefill_instance = subprocess.Popen( - prefill_cmd, - stdout=subprocess.PIPE, - shell=True, - preexec_fn=os.setsid, - ) +def start_prefill(model_name_or_path): + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + os.environ["FD_LOG_DIR"] = "log_prefill" + llm_prefill = LLM( + model=model_name_or_path, + tensor_parallel_size=1, + splitwise_role="prefill", + engine_worker_queue_port=6677, + cache_queue_port=55667, + ) +def main(): + prefill = multiprocessing.Process( + target=start_prefill, + args=(model_name_or_path,)).start() + time.sleep(10) + llm_decode = start_decode(model_name_or_path) + + output = llm_decode.generate(prompts=["who are you?", "what can you do?"], use_tqdm=True) + print(output) + + decode.join() -# # 超参设置 -os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7" -os.environ["FD_LOG_DIR"] = "log_decode" -sampling_params = SamplingParams(temperature=0.1, max_tokens=30) -llm_decode = LLM( - model=model_name_or_path, - tensor_parallel_size=4, - splitwise_role="decode", - engine_worker_queue_port=6678, - innode_prefill_ports=[6676], - cache_queue_port=55668 - ) - - -output = llm_decode.generate(prompts=["who are you?", "what can you do?"], use_tqdm=True) -print(output) - - -os.killpg(prefill_instance.pid, signal.SIGTERM) \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 047b42b2f..fefca65c1 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -87,10 +87,14 @@ class EngineArgs: """ Configuration for speculative execution. """ - dynamic_load_weight: int = 0 + dynamic_load_weight: bool = False """ dynamic load weight """ + load_strategy: str = "meta" + """ + dynamic load weight strategy + """ quantization: str = None guided_decoding_backend: str = "off" """ @@ -364,13 +368,16 @@ class EngineArgs: type=json.loads, default=EngineArgs.speculative_config, help="Configuration for speculative execution.") - model_group.add_argument( "--dynamic-load-weight", - type=int, + action='store_true', default=EngineArgs.dynamic_load_weight, help="Flag to indicate whether to load weight dynamically.") - + model_group.add_argument( + "--load-strategy", + type=str, + default=EngineArgs.load_strategy, + help="Flag to dynamic load strategy.") model_group.add_argument("--engine-worker-queue-port", type=int, default=EngineArgs.engine_worker_queue_port, @@ -383,6 +390,7 @@ class EngineArgs: "default is None. The priority of this configuration "\ "is lower than that of the config file. 
" \ "More complex quantization methods need to be configured via the config file.") + model_group.add_argument( "--enable-static-graph-inference", action='store_true', @@ -668,8 +676,9 @@ class EngineArgs: """ return ModelConfig(model_name_or_path=self.model, config_json_file=self.model_config_name, + quantization=self.quantization, dynamic_load_weight=self.dynamic_load_weight, - quantization=self.quantization) + load_strategy=self.load_strategy) def create_cache_config(self, model_cfg) -> CacheConfig: """ @@ -749,6 +758,9 @@ class EngineArgs: speculative_cfg = self.create_speculative_config() + assert not (self.use_cudagraph and self.enable_prefix_caching), \ + "Prefix caching cannot be used with CUDA graph" + return Config( model_name_or_path=self.model, model_config=model_cfg, diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index 6b8d6f3a4..a4efb0c61 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -41,7 +41,8 @@ class ModelConfig: def __init__(self, model_name_or_path: str, config_json_file: str = "config.json", - dynamic_load_weight: int = 0, + dynamic_load_weight: bool = False, + load_strategy: str="meta", quantization: str = None, download_dir: Optional[str] = None): """ @@ -55,6 +56,7 @@ class ModelConfig: self.model_dir = model_name_or_path self.is_unified_ckpt = check_unified_ckpt(self.model_dir) self.dynamic_load_weight = dynamic_load_weight + self.load_strategy = load_strategy self.quantization = quantization config_file = os.path.join(model_name_or_path, config_json_file) @@ -584,12 +586,10 @@ class Config: self.guided_decoding_backend = guided_decoding_backend self.disable_any_whitespace = disable_any_whitespace - if self.innode_prefill_ports is not None: if not isinstance(self.innode_prefill_ports, list): ports = str(self.innode_prefill_ports).split(',') self.innode_prefill_ports = [int(port) for port in ports] - assert self.splitwise_role in ["mixed", "prefill", "decode"] @@ -728,7 +728,7 @@ class Config: ), "XPU currently do not support guided_decoding" try: - pass + import xgrammar except Exception as e: raise Exception( f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. 
\n\t {e}" diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 162c89078..dcac96a5b 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -286,6 +286,8 @@ class LLMEngine(object): while self.running: try: results = self.scheduler.get_results() + if len(results) == 0: + time.sleep(0.001) for request_id, contents in results.items(): for result in contents: self.zmq_server.send_multipart(request_id, result) @@ -444,8 +446,8 @@ class LLMEngine(object): enable_thinking = None if kwargs is not None: enable_thinking = kwargs.get("enable_thinking", None) - request = self.data_processor.process_request(request, - self.cfg.max_model_len, enable_thinking=enable_thinking) + request = self.data_processor.process_request( + request, self.cfg.max_model_len, enable_thinking=enable_thinking) request.prompt_token_ids_len = len(request.prompt_token_ids) input_ids_len = request.prompt_token_ids_len request.set( @@ -453,7 +455,8 @@ class LLMEngine(object): min(self.cfg.max_model_len - input_ids_len, request.get("max_tokens"))) if request.get("reasoning_max_tokens") is None: - default_reasoning_max_tokens = max(int(request.get("max_tokens") * 0.8), 1) + default_reasoning_max_tokens = max( + int(request.get("max_tokens") * 0.8), 1) request.set("reasoning_max_tokens", default_reasoning_max_tokens) min_tokens = request.get("min_tokens") if input_ids_len + min_tokens >= self.cfg.max_model_len: @@ -963,8 +966,8 @@ class LLMEngine(object): "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python", "FLAGS_use_append_attn": 1, "NCCL_ALGO": "Ring", - "FLAGS_hardamard_moe_block_size": 128, "FLAGS_max_partition_size": 32768, + "FLAGS_hardamard_moe_block_size": 128, } # environment variables needed by Dy2St variables.update({ @@ -1017,6 +1020,12 @@ class LLMEngine(object): worker_path = "../worker/vl_worker_process.py" py_script = os.path.join(current_dir_path, worker_path) + ori_vocab_size = ( + len(self.data_processor.tokenizer.sp_model) + if hasattr(self.data_processor.tokenizer, 'sp_model') + else len(self.data_processor.tokenizer.vocab) + ) + arguments = ( f" --nnodes {str(self.cfg.nnode)}" f" --devices {self.cfg.device_ids} {py_script}" @@ -1037,13 +1046,14 @@ class LLMEngine(object): f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}" f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}" f" --quantization {self.cfg.model_config.quantization}" - f" --ori_vocab_size {len(self.data_processor.tokenizer)}" + f" --ori_vocab_size {ori_vocab_size}" f" --speculative_method {self.cfg.speculative_config.method}" f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}" f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}" f" --speculative_model_quantization {self.cfg.speculative_config.quantization}" f" --max_capture_batch_size {self.cfg.max_capture_batch_size}" - f" --guided_decoding_backend {self.cfg.guided_decoding_backend}") + f" --guided_decoding_backend {self.cfg.guided_decoding_backend}" + f" --load_strategy {self.cfg.model_config.load_strategy}") worker_append_flag = { "enable_expert_parallel": @@ -1188,8 +1198,9 @@ class LLMEngine(object): line = line.decode('utf-8', errors='ignore') if self.worker_init_status.get("finished", False): break - if match := re.search(r'Loading checkpoint shards:\s*(\d+)', - line): + if match := re.search( + r'Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)', + line): self.worker_init_status["weight_loadding"] = eval( match.group(1)) * 
1.0 / 100 elif (match := re.search(r'Start load layer (\d+)', diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 23c597b05..6d75b6b57 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -220,6 +220,9 @@ class OpenAIServingChat: choice.finish_reason = "tool_calls" else: choice.finish_reason = "length" + + if res.get("error_msg") is not None and "Recover" in res["error_msg"]: + choice.finish_reason = "length" if request.metadata is not None and request.metadata.get("training", False) and delta_text != "": choice.delta.token_ids = output["token_ids"] @@ -335,6 +338,9 @@ class OpenAIServingChat: choice.finish_reason = "tool_calls" else: choice.finish_reason = "length" + + if final_res.get("error_msg") is not None and "Recover" in final_res["error_msg"]: + choice.finish_reason = "length" choices.append(choice) num_prompt_tokens = len(prompt_token_ids) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index da276d8f5..8ef8a5149 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -82,13 +82,21 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"), + # Set whether to disable recompute the request when the KV cache is full. + "FD_DISABLED_RECOVER": + lambda: os.getenv("FD_DISABLED_RECOVER", "0"), + # Set triton kernel JIT compilation directory. "FD_TRITON_KERNEL_CACHE_DIR": lambda: os.getenv("FD_TRITON_KERNEL_CACHE_DIR", None), # Whether transition from standalone PD decoupling to centralized inference "FD_PD_CHANGEABLE": - lambda: os.getenv("FD_PD_CHANGEABLE", "1"), + lambda: os.getenv("FD_PD_CHANGEABLE", "0"), + + # Whether to use fastsafetensor load weight (0 or 1) + "FD_USE_FASTSAFETENSOR": + lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"), } diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py index f8d4976d5..51dbed766 100644 --- a/fastdeploy/input/ernie_processor.py +++ b/fastdeploy/input/ernie_processor.py @@ -27,6 +27,7 @@ from fastdeploy.input.text_processor import BaseDataProcessor _SAMPLING_EPS = 1e-5 + class ErnieProcessor(BaseDataProcessor): """ 初始化模型实例。 @@ -160,6 +161,7 @@ class ErnieProcessor(BaseDataProcessor): if request.get('prompt'): prompt = request.get('prompt') prompt = prompt[0] if isinstance(prompt, list) else prompt + tokens = self.tokenizer.tokenize(prompt) token_ids = self.tokenizer.convert_tokens_to_ids(tokens) request['prompt_token_ids'] = token_ids diff --git a/fastdeploy/input/ernie_tokenizer.py b/fastdeploy/input/ernie_tokenizer.py index 13a3c1a79..d6392d5e2 100644 --- a/fastdeploy/input/ernie_tokenizer.py +++ b/fastdeploy/input/ernie_tokenizer.py @@ -82,6 +82,8 @@ class ErnieBotTokenizer(PretrainedTokenizer): self.vocab_file = vocab_file self.sp_model = spm.SentencePieceProcessor() self.sp_model.Load(vocab_file) + # pre-process map-type all spec token for decode accelerate. 
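+        # A set makes the `token in self.all_spec_tok` checks during decoding O(1)
+        # instead of rescanning the list of special tokens for every token.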
+ self.all_spec_tok = set(self.all_special_tokens) @property def space_token(self): @@ -143,7 +145,7 @@ class ErnieBotTokenizer(PretrainedTokenizer): # prev_is_special = False for token in tokens: # make sure that special tokens are not decoded using sentencepiece model - if token in self.all_special_tokens: + if token in self.all_spec_tok: # if not prev_is_special: # out_string += " " out_string += self.sp_model.decode(current_sub_tokens) + token @@ -216,7 +218,7 @@ class ErnieBotTokenizer(PretrainedTokenizer): if hasattr(self, "do_lower_case") and self.do_lower_case: # convert non-special tokens to lowercase escaped_special_toks = [ - re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_special_tokens) + re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_spec_tok) ] pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index 4f32a2936..ae2dc1f29 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -25,6 +25,7 @@ from fastdeploy.utils import data_processor_logger _SAMPLING_EPS = 1e-5 + class BaseDataProcessor(ABC): """base class for data processor""" diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 22de36bfe..afbf916a5 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -16,10 +16,12 @@ from .attention import Attention from .append_attn_backend import AppendAttentionBackend from .attention_selecter import get_attention_backend from .base_attention_backend import AttentionBackend +from .mla_attention_backend import MLAAttentionBackend from .native_paddle_backend import PaddleNativeAttnBackend from .xpu_attn_backend import XPUAttentionBackend __all__ = [ "Attention", "AttentionBackend", "PaddleNativeAttnBackend", - "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend" + "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend", + "MLAAttentionBackend" ] diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 3c6446cdb..eb82e0bf9 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -187,6 +187,8 @@ class AppendAttentionBackend(AttentionBackend): k: paddle.Tensor, v: paddle.Tensor, qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, layer: Attention, forward_meta: ForwardMeta, ) -> paddle.Tensor: diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py index ea06feff8..3f676f031 100644 --- a/fastdeploy/model_executor/layers/attention/attention.py +++ b/fastdeploy/model_executor/layers/attention/attention.py @@ -111,6 +111,8 @@ class Attention(nn.Layer): k: paddle.Tensor = None, v: paddle.Tensor = None, qkv: paddle.Tensor = None, + compressed_kv: paddle.Tensor = None, + k_pe: paddle.Tensor = None, forward_meta: ForwardMeta = None, ) -> paddle.Tensor: """ @@ -120,12 +122,16 @@ class Attention(nn.Layer): k: the key tensor v: the value tensor forward_meta: the forward meta data + compressed_kv: optional compressed key-value cache (for MLA) + k_pe: optional key positional 
encoding (for MLA) """ return forward_meta.attn_backend.forward( q, k, v, qkv, + compressed_kv, + k_pe, self, forward_meta, ) diff --git a/fastdeploy/model_executor/layers/attention/attention_selecter.py b/fastdeploy/model_executor/layers/attention/attention_selecter.py index a20adfaaa..3db03b188 100644 --- a/fastdeploy/model_executor/layers/attention/attention_selecter.py +++ b/fastdeploy/model_executor/layers/attention/attention_selecter.py @@ -16,6 +16,7 @@ from functools import cache +from fastdeploy import envs from fastdeploy.platforms import _Backend, current_platform from fastdeploy.utils import resolve_obj_from_strname @@ -40,6 +41,7 @@ def _get_attn_backend(selected_backend: str) -> object: return resolve_obj_from_strname(attention_cls) -def get_attention_backend(selected_backend): - """Selects which attention backend .""" - return _get_attn_backend(selected_backend) +def get_attention_backend() -> object: + """Selects which attention backend.""" + attention_backend = envs.FD_ATTENTION_BACKEND + return _get_attn_backend(attention_backend) diff --git a/fastdeploy/model_executor/layers/attention/base_attention_backend.py b/fastdeploy/model_executor/layers/attention/base_attention_backend.py index eb971cb2b..02d1d65db 100644 --- a/fastdeploy/model_executor/layers/attention/base_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/base_attention_backend.py @@ -46,6 +46,8 @@ class AttentionBackend(ABC): k: paddle.Tensor, v: paddle.Tensor, qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, layer: paddle.nn.Layer, forward_meta: ForwardMeta, ) -> paddle.Tensor: @@ -56,6 +58,8 @@ class AttentionBackend(ABC): k: The key tensor. v: The value tensor. layer: The layer that will be used for the forward. + compressed_kv: optional compressed key-value cache (for MLA) + k_pe: optional key positional encoding (for MLA) forward_meta: The forward metadata. """ if forward_meta.forward_mode.is_mixed(): @@ -64,6 +68,8 @@ class AttentionBackend(ABC): k, v, qkv, + compressed_kv, + k_pe, layer, forward_meta, ) @@ -73,6 +79,8 @@ class AttentionBackend(ABC): k, v, qkv, + compressed_kv, + k_pe, layer, forward_meta, ) @@ -82,6 +90,8 @@ class AttentionBackend(ABC): k, v, qkv, + compressed_kv, + k_pe, layer, forward_meta, ) @@ -92,6 +102,8 @@ class AttentionBackend(ABC): k: paddle.Tensor, v: paddle.Tensor, qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, layer: paddle.nn.Layer, forward_meta: ForwardMeta, ) -> paddle.Tensor: @@ -104,6 +116,8 @@ class AttentionBackend(ABC): k: paddle.Tensor, v: paddle.Tensor, qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, layer: paddle.nn.Layer, forward_meta: ForwardMeta, ) -> paddle.Tensor: @@ -116,6 +130,8 @@ class AttentionBackend(ABC): k: paddle.Tensor, v: paddle.Tensor, qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, layer: paddle.nn.Layer, forward_meta: ForwardMeta, ) -> paddle.Tensor: diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py new file mode 100644 index 000000000..52489bd5f --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -0,0 +1,490 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +import math +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Tuple + +import paddle +from paddle.nn.functional.flash_attention import flash_attn_unpadded + +from fastdeploy.model_executor.layers.attention.ops import ( + get_block_shape_and_split_kv_block, init_signal_layerwise, + open_shm_and_get_meta_signal) +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import (decode_mla_write_cache, + multi_head_latent_attention, + prefill_mla_write_cache) + +if TYPE_CHECKING: + from paddle._typing.dtype_like import _DTypeLiteral + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, AttentionMetadata) +from fastdeploy.worker.forward_meta import ForwardMeta + + +def yarn_get_mscale(scale=1, mscale=1): + """ + """ + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +@dataclass +class MLAAttentionMetadata(AttentionMetadata): + """ + MLAAttentionMetadata for Multi-Layer Attention + """ + max_len_kv: paddle.Tensor = None + set_max_lengths: int = -1 + encoder_batch_ids: paddle.Tensor = None + encoder_tile_ids_per_batch: paddle.Tensor = None + encoder_num_blocks: paddle.Tensor = None + kv_batch_ids: paddle.Tensor = None + kv_tile_ids_per_batch: paddle.Tensor = None + kv_num_blocks: paddle.Tensor = None + decoder_batch_ids: paddle.Tensor = None + decoder_tile_ids_per_batch: paddle.Tensor = None + decoder_num_blocks: paddle.Tensor = None + + _dtype: _DTypeLiteral = paddle.bfloat16 + encoder_max_partition_size: int = 32768 + max_partition_size: int = 32768 + block_tables: Optional[paddle.Tensor] = None + rotary_embs: Optional[paddle.Tensor] = None + attn_mask: Optional[paddle.Tensor] = None + encoder_block_shape_q: Optional[paddle.Tensor] = None + decoder_block_shape_q: Optional[paddle.Tensor] = None + _fuse_kernel_compute_dtype: str = "bf16" + + # pd_disaggregation + kv_signal_metadata: Optional[paddle.Tensor] = None + kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list) + + +class MLAAttentionBackend(AttentionBackend): + """ + MLA Attention Backend implementation. 
+ """ + + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, + head_dim: int) -> None: + """ + MLAAttentionBackend __init__ + """ + super().__init__() + self.attention_metadata: MLAAttentionMetadata = None + + # 基础配置 + self.block_size: int = fd_config.parallel_config.block_size + self.max_seq_len: int = fd_config.parallel_config.max_model_len + self.rope_theta: float = (10000.0 + if fd_config.model_config.rope_theta is None + else fd_config.model_config.rope_theta) + self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.causal: bool = getattr(fd_config.model_config, "causal", True) + self.speculative_method: str = fd_config.speculative_config.method + self.use_speculate: bool = self.speculative_method is not None + self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens + self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" + self.rank: int = fd_config.parallel_config.tensor_parallel_rank + + self.kv_num_heads: int = kv_num_heads + self.num_heads: int = num_heads + self.head_dim: int = fd_config.model_config.head_dim + self.num_layers: int = fd_config.model_config.num_layers + + # For Multi Head Latent Attention + self.kv_lora_rank: int = fd_config.model_config.deepseekv3.kv_lora_rank + self.qk_rope_head_dim: int = fd_config.model_config.deepseekv3.qk_rope_head_dim + self.qk_head_dim: int = fd_config.model_config.deepseekv3.qk_nope_head_dim \ + + fd_config.model_config.deepseekv3.qk_rope_head_dim + self.attn_softmax_scale: float = self.qk_head_dim**-0.5 + if fd_config.model_config.deepseekv3.rope_scaling: + mscale_all_dim = fd_config.model_config.deepseekv3.rope_scaling.get( + "mscale_all_dim", False) # 1.0 + scaling_factor = fd_config.model_config.deepseekv3.rope_scaling[ + "factor"] # 40 + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale + + # pd_disaggregation + self.use_pd_disaggregation: int = int( + os.getenv("FLAGS_use_pd_disaggregation", 0)) + self.start_layer_index: int = fd_config.model_config.start_layer_index + self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None) + if self.device_id is None: + self.device_id = self.rank + else: + self.device_id = self.device_id.split(",")[self.rank] + + def init_attention_metadata(self, forward_meta: ForwardMeta): + """Initialize attention metadata hence all layers in the forward pass can reuse it.""" + metadata = MLAAttentionMetadata() + metadata.encoder_block_shape_q = 64 + metadata.decoder_block_shape_q = 16 + metadata.max_partition_size = 32768 + metadata.encoder_max_partition_size = self.max_seq_len + metadata._dtype = paddle.get_default_dtype() + if metadata._dtype == "bfloat16": + metadata._fuse_kernel_compute_dtype = "bf16" + elif metadata._dtype == "float16": + metadata._fuse_kernel_compute_dtype = "fp16" + elif metadata._dtype == "float32": + metadata._fuse_kernel_compute_dtype = "fp32" + + metadata.block_tables = forward_meta.block_tables + metadata.rotary_embs = forward_meta.rotary_embs + metadata.attn_mask = forward_meta.attn_mask + metadata.pre_caches_length = forward_meta.pre_caches_length + + ( + metadata.encoder_batch_ids, + metadata.encoder_tile_ids_per_batch, + metadata.encoder_num_blocks, + metadata.kv_batch_ids, + metadata.kv_tile_ids_per_batch, + metadata.kv_num_blocks, + metadata.decoder_batch_ids, + metadata.decoder_tile_ids_per_batch, + metadata.decoder_num_blocks, + metadata.max_len_kv, + metadata.set_max_lengths, + 
) = get_block_shape_and_split_kv_block( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cum_offsets, + metadata.encoder_block_shape_q, + metadata.decoder_block_shape_q, + self.num_heads // self.kv_num_heads, + self.block_size, + self.speculate_max_draft_token_num + 1, + ) + + # MLA + metadata.max_enc_len_this_time = metadata.set_max_lengths[1] + metadata.max_dec_len_this_time = metadata.set_max_lengths[2] + + # pd_disaggregation + metadata.kv_signal_data_list = [None] * self.num_layers + if self.use_pd_disaggregation: + metadata.kv_signal_metadata = open_shm_and_get_meta_signal( + self.rank, int(self.device_id), self.keep_pd_step_flag) + + self.attention_metadata: AttentionMetadata = metadata + + def get_attntion_meta(self) -> AttentionMetadata: + """get_attntion_meta""" + return self.attention_metadata + + def get_kv_cache_shape(self, + max_num_blocks: int) -> Tuple[int, int, int, int]: + """ + Calculate kv cache shape for MLA + """ + return (max_num_blocks, 1, self.block_size, + self.kv_lora_rank + self.qk_rope_head_dim) + + def forward_extend( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """ + Prefill阶段的前向传播 + """ + metadata = self.attention_metadata + + if self.use_pd_disaggregation: + metadata.kv_signal_data_list[ + layer.layer_id] = init_signal_layerwise( + metadata.kv_signal_metadata, + layer.layer_id + self.start_layer_index) + + latent_cache = forward_meta.caches[layer.layer_id] if hasattr( + forward_meta, 'caches') else None + + # 写入缓存 + prefill_mla_write_cache( + compressed_kv, + k_pe, + latent_cache, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.padding_offset, + forward_meta.cum_offsets, + metadata.block_tables, + "none", + getattr(forward_meta, 'max_input_length', -1), + ) + + # Flash注意力计算 + fmha_out = flash_attn_unpadded( + q, + k, + v, + forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_k, + metadata.max_enc_len_this_time, + metadata.max_enc_len_this_time, + self.attn_softmax_scale, + causal=True, + training=False, + )[0] + + return fmha_out + + def forward_decode( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """ + Decode阶段的前向传播 + """ + metadata = self.attention_metadata + + if self.use_pd_disaggregation: + metadata.kv_signal_data_list[ + layer.layer_id] = init_signal_layerwise( + metadata.kv_signal_metadata, + layer.layer_id + self.start_layer_index) + + latent_cache = forward_meta.caches[layer.layer_id] if hasattr( + forward_meta, 'caches') else None + + # 获取推测解码参数 + speculate_decoder = self.speculative_method is not None + speculate_max_tokens = self.speculate_max_draft_token_num + + # 写入缓存 + decode_mla_write_cache( + compressed_kv, + k_pe, + latent_cache, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_encoder, + forward_meta.padding_offset, + forward_meta.cum_offsets, + metadata.block_tables, + "none", + self.max_seq_len, + speculate_decoder, + ) + + # 多头潜在注意力计算 + fmha_out = multi_head_latent_attention( + q, + latent_cache, + latent_cache, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cu_seqlens_q, + forward_meta.padding_offset, + 
+            forward_meta.cum_offsets,
+            metadata.block_tables,
+            metadata.encoder_batch_ids,
+            metadata.encoder_tile_ids_per_batch,
+            metadata.encoder_num_blocks,
+            metadata.kv_batch_ids,
+            metadata.kv_tile_ids_per_batch,
+            metadata.kv_num_blocks,
+            metadata.decoder_batch_ids,
+            metadata.decoder_tile_ids_per_batch,
+            metadata.decoder_num_blocks,
+            metadata.
+            decoder_num_blocks,  # PaddleNLP passes decoder_num_blocks_cpu here
+            metadata.max_enc_len_this_time,
+            metadata.max_dec_len_this_time,
+            metadata.max_len_kv,
+            None,  # attn_mask
+            None,  # qkv_bias
+            None,  # qkv_out_scales
+            None,  # cache_k_quant_scales
+            None,  # cache_v_quant_scales
+            None,  # cache_k_dequant_scales
+            None,  # cache_v_dequant_scales
+            None,  # cache_k_zp
+            None,  # cache_v_zp
+            None,  # out_shifts
+            None,  # out_smooths
+            metadata._fuse_kernel_compute_dtype,
+            "none",  # cache_quant_type
+            self.kv_lora_rank,
+            self.max_seq_len,
+            self.attn_softmax_scale,
+            0.0,  # quant_max_bound
+            0.0,  # quant_min_bound
+            0.0,  # out_linear_in_scale
+            speculate_max_tokens,
+            True,  # causal
+            speculate_decoder,
+        )
+
+        return fmha_out
+
+    def forward_mixed(
+        self,
+        q: paddle.Tensor,
+        k: paddle.Tensor,
+        v: paddle.Tensor,
+        qkv: paddle.Tensor,
+        compressed_kv: paddle.Tensor,
+        k_pe: paddle.Tensor,
+        layer: Attention,
+        forward_meta: ForwardMeta,
+    ) -> paddle.Tensor:
+        """
+        Forward pass in mixed mode.
+        """
+        metadata = self.attention_metadata
+        speculate_decoder = self.speculative_method is not None
+        speculate_max_tokens = self.speculate_max_draft_token_num
+
+        decode_stage = forward_meta.is_decode_batch
+        prefill_stage = not (forward_meta.is_decode_batch)
+
+        if self.use_pd_disaggregation:
+            metadata.kv_signal_data_list[
+                layer.layer_id] = init_signal_layerwise(
+                    metadata.kv_signal_metadata,
+                    layer.layer_id + self.start_layer_index)
+
+        latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
+            forward_meta, 'caches') else None
+
+        if prefill_stage:
+            # Write the compressed KV into the latent cache
+            prefill_mla_write_cache(
+                compressed_kv,
+                k_pe,
+                latent_cache,
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.padding_offset,
+                forward_meta.cum_offsets,
+                metadata.block_tables,
+                "none",
+                self.max_seq_len,
+            )
+
+            # FA
+            fmha_out = flash_attn_unpadded(
+                q,
+                k,
+                v,
+                forward_meta.cu_seqlens_q,
+                forward_meta.cu_seqlens_k,
+                metadata.max_enc_len_this_time,
+                metadata.max_enc_len_this_time,
+                self.attn_softmax_scale,
+                causal=True,
+                training=False,
+            )[0]
+
+            return fmha_out
+
+        # Decode
+        if decode_stage:
+            # MLA write to the latent cache
+            decode_mla_write_cache(
+                compressed_kv,
+                k_pe,
+                latent_cache,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_encoder,
+                forward_meta.padding_offset,
+                forward_meta.cum_offsets,
+                metadata.block_tables,
+                "none",
+                self.max_seq_len,
+                speculate_decoder,
+            )
+
+            # Multi-head latent attention computation
+            fmha_out = multi_head_latent_attention(
+                q,
+                latent_cache,
+                latent_cache,
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.cu_seqlens_q,
+                forward_meta.padding_offset,
+                forward_meta.cum_offsets,
+                metadata.block_tables,
+                metadata.encoder_batch_ids,
+                metadata.encoder_tile_ids_per_batch,
+                metadata.encoder_num_blocks,
+                metadata.kv_batch_ids,
+                metadata.kv_tile_ids_per_batch,
+                metadata.kv_num_blocks,
+                metadata.decoder_batch_ids,
+                metadata.decoder_tile_ids_per_batch,
+                metadata.decoder_num_blocks,
+                metadata.
+                decoder_num_blocks,  # PaddleNLP passes decoder_num_blocks_cpu here
+                metadata.max_enc_len_this_time,
+                metadata.max_dec_len_this_time,
+                metadata.max_len_kv,
+                None,  # attn_mask
+                None,  # qkv_bias
+                None,  # qkv_out_scales
+                None,  # cache_k_quant_scales
+                None,  # cache_v_quant_scales
+                None,  # cache_k_dequant_scales
+                None,  # cache_v_dequant_scales
+                None,  # cache_k_zp
+                None,  # cache_v_zp
+                None,  # out_shifts
+                None,  # out_smooths
+                metadata._fuse_kernel_compute_dtype,
+                "none",  # cache_quant_type
+                self.kv_lora_rank,
+                self.max_seq_len,
+                self.attn_softmax_scale,
+                0.0,  # quant_max_bound
+                0.0,  # quant_min_bound
+                0.0,  # out_linear_in_scale
+                speculate_max_tokens,
+                True,  # causal
+                speculate_decoder,
+            )
+
+            return fmha_out
diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py
index f3fcdf8d5..2ea49f299 100644
--- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py
@@ -149,6 +149,8 @@ class XPUAttentionBackend(AttentionBackend):
         k: paddle.Tensor,
         v: paddle.Tensor,
         qkv: paddle.Tensor,
+        compressed_kv: paddle.Tensor,
+        k_pe: paddle.Tensor,
         layer: Attention,
         forward_meta: ForwardMeta,
     ) -> paddle.Tensor:
diff --git a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py
index ceb445a88..db264b07a 100644
--- a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py
+++ b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py
@@ -41,16 +41,12 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         """
         Create weights for linear layer on XPU
         """
+        # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
+        linear_weight_scale_shape = [layer.linear_weight_shape[1]]
         layer.linear_weight_shape.reverse()
         if self.quant_config.name() == "weight_only_int4":
             layer.linear_weight_shape[0] //= 2
         layer.weight_dtype = "int8"
-        linear_weight_scale_shape = [layer.embed_dim]
-        if hasattr(layer, "linear_weight_shape"):
-            if isinstance(layer.linear_weight_shape, list):
-                layer_weight_shape = layer.linear_weight_shape
-                linear_weight_scale_shape = layer_weight_shape[:1]
-
         layer.linear_weight_scale = layer.create_parameter(
             shape=linear_weight_scale_shape,
             dtype="float32",
diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py
index 86fb06c8b..bc67cb133 100644
--- a/fastdeploy/model_executor/layers/embeddings.py
+++ b/fastdeploy/model_executor/layers/embeddings.py
@@ -14,10 +14,15 @@
 # limitations under the License.
 """
 
+from typing import Dict
+
+import numpy as np
 import paddle
 from paddle import nn
 from paddle.distributed import fleet
 
+from fastdeploy.config import FDConfig
+
 from .utils import get_tensor
 
 
@@ -28,12 +33,12 @@ class VocabParallelEmbedding(nn.Layer):
 
     def __init__(
         self,
-        fd_config,
-        num_embeddings,
-        embedding_dim=768,
-        params_dtype="bfloat16",
+        fd_config: FDConfig,
+        num_embeddings: int,
+        embedding_dim: int = 768,
+        params_dtype: str = "bfloat16",
         prefix="",
-    ):
+    ) -> None:
         """
         Initialize the VocabParallelEmbedding layer for the model.
 
@@ -41,28 +46,28 @@ class VocabParallelEmbedding(nn.Layer):
             fd_config (FDConfig): Arguments related to inference, containing
                 attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
                 num_attention_heads, and ffn_hidden_size.
- num_embeddings : vocabulary size. - embedding_dim : size of hidden state. - params_dtype : data type of parameters. - prefix (str): Unique name of the layer, used for naming internal attributes, - you can give it any name you like. + num_embeddings (int) : vocabulary size. + embedding_dim (int) : size of hidden state. + params_dtype (str) : data type of parameters. + prefix (str): The name of current layer. Defaults to "". """ super().__init__() self.fd_config = fd_config hcg = fleet.get_hybrid_communicate_group() - self.mp_rank = hcg.get_model_parallel_rank() - self.column_cut = fd_config.parallel_config.column_cut - self.world_size = hcg.get_model_parallel_world_size() - self.ring_id = hcg.get_model_parallel_group().id - self.use_rope = fd_config.model_config.use_rope - self.rope_head_dim = fd_config.model_config.rope_head_dim - self.use_ep = fd_config.parallel_config.use_ep - self.hidden_dropout_prob = fd_config.model_config.hidden_dropout_prob - self.initializer_range = fd_config.model_config.initializer_range - self.sequence_parallel = fd_config.parallel_config.sequence_parallel - self.max_position_embeddings = fd_config.model_config.max_position_embeddings - self.freeze_embedding = fd_config.model_config.freeze_embedding - self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings + self.mp_rank: int = hcg.get_model_parallel_rank() + self.column_cut = False + self.world_size: int = hcg.get_model_parallel_world_size() + self.ring_id: int = hcg.get_model_parallel_group().id + self.use_rope: bool = fd_config.model_config.use_rope + self.rope_head_dim: int = fd_config.model_config.rope_head_dim + self.use_ep: bool = fd_config.parallel_config.use_ep + self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob + self.initializer_range: float = fd_config.model_config.initializer_range + self.sequence_parallel: bool = fd_config.parallel_config.sequence_parallel + self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings + self.freeze_embedding: bool = fd_config.model_config.freeze_embedding + self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings + self.params_dtype: str = params_dtype if self.use_ep: self.word_embeddings = nn.Embedding( @@ -109,7 +114,8 @@ class VocabParallelEmbedding(nn.Layer): self.rope_head_dim_shape_tensor = paddle.ones((self.rope_head_dim), dtype="int8") - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: Dict[str, + paddle.Tensor | np.ndarray]): """ Load the checkpoint state dictionary into the layer. @@ -125,7 +131,7 @@ class VocabParallelEmbedding(nn.Layer): get_tensor(state_dict.pop(self.prefix + ".weight")).astype( paddle.get_default_dtype())) - def forward(self, ids_remove_padding=None): + def forward(self, ids_remove_padding=None) -> paddle.Tensor: """ Defines the forward computation of the layer. 
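
For context, the `VocabParallelEmbedding` touched above follows the usual tensor-parallel scheme: the vocabulary rows of the embedding table are sharded across model-parallel ranks, each rank looks up only the rows it owns, and an all-reduce sums the partial results. Below is a minimal NumPy sketch of that idea; it is illustrative only, not the FastDeploy or fleet implementation, and the function name and single-process "ranks" loop are assumptions made for the example.

```python
import numpy as np

def vocab_parallel_lookup(token_ids, full_weight, mp_rank, mp_world_size):
    """Partial embedding lookup for one rank of a row-sharded vocab table."""
    vocab_size = full_weight.shape[0]
    part = vocab_size // mp_world_size
    lo, hi = mp_rank * part, (mp_rank + 1) * part
    local_weight = full_weight[lo:hi]                 # rows owned by this rank
    mask = (token_ids >= lo) & (token_ids < hi)       # tokens this rank can serve
    local_ids = np.where(mask, token_ids - lo, 0)
    out = local_weight[local_ids]                     # fancy indexing copies
    out[~mask] = 0.0                                  # other ranks fill these in
    return out                                        # summed across ranks by all-reduce

# Single-process check simulating two ranks: the summed partials equal a full lookup.
rng = np.random.default_rng(0)
W = rng.standard_normal((8, 4)).astype("float32")
ids = np.array([1, 5, 7, 2])
summed = sum(vocab_parallel_lookup(ids, W, r, 2) for r in range(2))
assert np.allclose(summed, W[ids])
```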
diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 7eb2cca0a..054ffe7f8 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -216,6 +216,14 @@ class ReplicatedLinear(LinearBase): with_bias=with_bias, add_bias=add_bias, skip_quant=skip_quant) + + self.hidden_size = fd_config.model_config.hidden_size + self.linear_weight_shape = [ + self.input_size, + self.output_size, + ] + if fd_config.quant_config: + self.quant_method.create_weights(self) self.init_weight() @@ -259,7 +267,10 @@ class ColumnParallelLinear(LinearBase): skip_quant=skip_quant) self.nranks = fd_config.parallel_config.tensor_parallel_degree self.input_size = input_size - self.output_size = divide(output_size, self.nranks) + self.output_size = divide( + output_size, + self.nranks) # Split the output_size using TP inference. + self.hidden_size = fd_config.model_config.hidden_size self.linear_weight_shape = [ self.input_size, self.output_size, @@ -339,7 +350,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): """ self.use_fast_ffn = use_fast_ffn self.activation = activation - self.embed_dim = fd_config.model_config.hidden_size + self.hidden_size = fd_config.model_config.hidden_size self.nranks = fd_config.parallel_config.tensor_parallel_degree super().__init__(fd_config=fd_config, @@ -413,12 +424,12 @@ class QKVParallelLinear(ColumnParallelLinear): """ self.num_heads = fd_config.model_config.num_attention_heads self.kv_num_heads = fd_config.model_config.num_key_value_heads - self.embed_dim = fd_config.model_config.hidden_size + self.hidden_size = fd_config.model_config.hidden_size self.head_dim = fd_config.model_config.head_dim self.nranks = fd_config.parallel_config.tensor_parallel_degree self.num_heads_per_rank = divide(self.num_heads, self.nranks) self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks) - input_size = self.embed_dim + input_size = self.hidden_size output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim super().__init__(fd_config=fd_config, prefix=prefix, @@ -448,7 +459,7 @@ class QKVParallelLinear(ColumnParallelLinear): weight_tensor = weight_tensor.reshape([ (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * (self.head_dim), - self.embed_dim, + self.hidden_size, ]) weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0]) @@ -513,6 +524,7 @@ class RowParallelLinear(LinearBase): output_size: int = None, with_bias: bool = False, add_bias: bool = False, + reduce_results: bool = True, skip_quant: bool = False, ): """ @@ -538,10 +550,14 @@ class RowParallelLinear(LinearBase): self.fd_config = fd_config self.skip_quant = False self.nranks = fd_config.parallel_config.tensor_parallel_degree - self.embed_dim = fd_config.model_config.hidden_size + self.hidden_size = fd_config.model_config.hidden_size self.head_dim = fd_config.model_config.head_dim self.num_heads = fd_config.model_config.num_attention_heads // self.nranks + # Split input_size when using TP inference. 
+        self.input_size = divide(input_size, self.nranks)
+        self.output_size = output_size
+
         self.linear_weight_shape = [
             self.input_size,
             self.output_size,
         ]
@@ -551,6 +567,8 @@ class RowParallelLinear(LinearBase):
         if fd_config.quant_config:
             self.quant_method = fd_config.quant_config.get_quant_method(self)
             self.quant_method.create_weights(self)
+
+        self.reduce_results = reduce_results
         self.init_weight()
 
     def init_weight(self):
@@ -570,7 +588,7 @@ class RowParallelLinear(LinearBase):
         self.linear_bias = None
         if self.with_bias:
             self.linear_bias = self.create_parameter(
-                shape=[self.embed_dim],
+                shape=[self.hidden_size],
                 dtype=self._dtype,
                 is_bias=True,
             )
@@ -589,7 +607,7 @@ class RowParallelLinear(LinearBase):
         else:
             out = paddle.matmul(x, self.linear_weight)
 
-        if self.nranks > 1:
+        if self.reduce_results and self.nranks > 1:
             tensor_model_parallel_all_reduce(out)
 
         return out
diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py
index 9c6a89ca8..1fac83f89 100644
--- a/fastdeploy/model_executor/layers/lm_head.py
+++ b/fastdeploy/model_executor/layers/lm_head.py
@@ -14,10 +14,15 @@
 # limitations under the License.
 """
 
+from typing import Dict, Optional
+
+import numpy as np
 import paddle
 from paddle import nn
 from paddle.distributed import fleet
 
+from fastdeploy.config import FDConfig
+
 from .utils import get_tensor
 
 
@@ -28,12 +33,12 @@ class ParallelLMHead(nn.Layer):
 
     def __init__(
         self,
-        fd_config,
-        num_embeddings,
-        embedding_dim,
-        prefix="",
-        with_bias=False,
-    ):
+        fd_config: FDConfig,
+        num_embeddings: int,
+        embedding_dim: int,
+        prefix: str = "",
+        with_bias: bool = False,
+    ) -> None:
         """
         Parallelized LMhead.
 
@@ -43,21 +48,22 @@
             num_attention_heads, and ffn_hidden_size.
             num_embeddings (int): vocabulary size.
             embedding_dim (int): size of hidden state.
-            prefix (str): full name of the layer in the state dict
+            prefix (str): The name of current layer. Defaults to "".
+            with_bias (bool): whether to have bias. Default: False.
         """
         super(ParallelLMHead, self).__init__()
-        self.linear_weight_key = prefix + ".weight"
+        self.linear_weight_key: str = prefix + ".weight"
         if with_bias:
-            self.linear_bias_key = prefix + ".bias"
+            self.linear_bias_key: Optional[str] = prefix + ".bias"
         else:
-            self.linear_bias_key = None
-        self.use_ep = fd_config.parallel_config.use_ep
+            self.linear_bias_key: Optional[str] = None
+        self.use_ep: bool = fd_config.parallel_config.use_ep
         self.column_cut = True
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
-        self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
+        self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
 
         if self.use_ep:
             self.weight = self.create_parameter(
@@ -92,7 +98,8 @@
                 fuse_matmul_bias=False,  # False gives a smaller numerical diff
             )
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: Dict[str,
+                                                paddle.Tensor | np.ndarray]):
         """
         Load the checkpoint state dictionary into the layer.
 
@@ -122,7 +129,7 @@
                 paddle.get_default_dtype())
             self.out_linear.bias.set_value(bias)
 
-    def forward(self, input):
+    def forward(self, input: paddle.Tensor) -> paddle.Tensor:
         """
         Defines the forward computation of the layer.
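
For context, the `reduce_results` switch added to `RowParallelLinear` in the linear.py hunk above controls whether the layer performs the tensor-parallel all-reduce itself. A row-parallel linear splits the input features (the weight's rows) across ranks, each rank computes a partial matmul, and summing the partials yields the full output; when a later op already reduces across ranks, the all-reduce can be skipped. The NumPy sketch below illustrates only the math; the helper name and the single-process simulation are assumptions for the example, and the real collective in the patch is `tensor_model_parallel_all_reduce`.

```python
import numpy as np

def row_parallel_partial_matmul(x, full_weight, rank, world_size):
    """Partial product for one rank when weight rows (input features) are sharded."""
    in_features = full_weight.shape[0]
    part = in_features // world_size
    rows = slice(rank * part, (rank + 1) * part)   # this rank's slice of the weight
    return x[:, rows] @ full_weight[rows, :]       # caller all-reduces when reduce_results=True

# Single-process check simulating two ranks: summing partials recovers the full matmul.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8)).astype("float32")
W = rng.standard_normal((8, 3)).astype("float32")
partials = [row_parallel_partial_matmul(x, W, r, 2) for r in range(2)]
assert np.allclose(sum(partials), x @ W, atol=1e-5)
```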
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 73778c02d..3c00ddfe4 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -22,14 +22,35 @@ from paddleformers.utils.log import logger import fastdeploy from fastdeploy.distributed.communication_op import \ tensor_model_parallel_all_reduce -from ..utils import get_tensor, create_and_set_parameter +from fastdeploy.platforms import current_platform + +from ..utils import create_and_set_parameter, get_tensor from .fused_moe_backend_base import MoEMethodBase -from fastdeploy.platforms import current_platform if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch - from fastdeploy.model_executor.ops.gpu import moe_expert_reduce - + from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch, + moe_expert_reduce, noaux_tc) + + +# used for deepseek_v3 +def get_moe_scores(gating_output: paddle.Tensor, n_group, topk_group, top_k, + routed_scaling_factor, + e_score_correction_bias) -> paddle.Tensor: + """ + compute moe scores using e_score_correction_bias. + """ + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) + scores = noaux_tc( + scores, + scores_with_bias, + n_group, + topk_group, + top_k, + routed_scaling_factor, + ) + return scores + class CutlassMoEMethod(MoEMethodBase): """ @@ -199,23 +220,47 @@ class CutlassMoEMethod(MoEMethodBase): """ Paddle Cutlass compute Fused MoE. """ - ( - permute_input, - token_nums_per_expert, - permute_indices_per_token, - topk_weights, - topk_idx, - expert_idx_per_token, - ) = moe_expert_dispatch( - x, - gate_out, - layer.gate_correction_bias, - (layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale") - else None), # if set, permute_input will be int8_t - layer.top_k, - False, - topk_only_mode=False, - ) + if layer.topk_method == "noaux_tc": + gate_out = get_moe_scores(gate_out, layer.n_group, + layer.topk_group, layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias) + + ( + permute_input, + token_nums_per_expert, + permute_indices_per_token, + topk_weights, + topk_idx, + expert_idx_per_token, + ) = moe_expert_dispatch( + x, + gate_out, + None, # Use layer.gate_correction_bias in get_moe_scores. 
+ (layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale") + else None), # if set, permute_input will be int8_t + layer.top_k, + False, + topk_only_mode=True, + ) + else: + ( + permute_input, + token_nums_per_expert, + permute_indices_per_token, + topk_weights, + topk_idx, + expert_idx_per_token, + ) = moe_expert_dispatch( + x, + gate_out, + layer.gate_correction_bias, + (layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale") + else None), # if set, permute_input will be int8_t + layer.top_k, + False, + topk_only_mode=False, + ) if self.moe_quant_type != "w4a8": # only w4a8 need expert_idx_per_token @@ -234,11 +279,11 @@ class CutlassMoEMethod(MoEMethodBase): permute_indices_per_token, topk_idx, None, - norm_topk_prob=True, + norm_topk_prob=False if layer.topk_method == "noaux_tc" else True, routed_scaling_factor=1.0, ) - if layer.tp_size > 1: + if layer.reduce_results and layer.tp_size > 1: tensor_model_parallel_all_reduce(fused_moe_out) return fused_moe_out diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index b888c99c3..ceb18edf0 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -195,8 +195,6 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): hidden_size = layer.hidden_size num_experts = layer.num_experts - gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight) - topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 00dca18df..267dab451 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -17,6 +17,7 @@ import paddle from paddle import nn +import fastdeploy from fastdeploy.distributed.communication_op import \ tensor_model_parallel_all_reduce from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map, @@ -25,17 +26,24 @@ from fastdeploy.utils import ceil_div from ..quantization.quant_base import QuantMethodBase +try: + from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func + + from .triton_moe_kernels import fused_moe_kernel_paddle +except: + pass + class TritonWeightOnlyMoEMethod(QuantMethodBase): """ Use Triton Group Gemm to compute Fused MoE. """ - def __init__(self, quant_method=None): + def __init__(self, quant_config=None): """ Triton Group Gemm to compute Fused MoE. 
""" - self.quant_method = quant_method + self.quant_config = quant_config self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"] self.added_scale_attrs = [ "moe_ffn1_weight_scale", "moe_ffn2_weight_scale" @@ -52,7 +60,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) assert len(ffn1_weights) == layer.num_local_experts assert len(ffn2_weights) == layer.num_local_experts - assert layer.quant_method.quant_config.name() == "wint8" + + algo = layer.quant_method.quant_config.name() + + assert algo == "wint8" + assert ffn1_weights[0].shape == [ layer.hidden_size, layer.moe_intermediate_size * 2 ] @@ -63,9 +75,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): ffn1_tensor = paddle.stack(ffn1_weights, axis=0) ffn2_tensor = paddle.stack(ffn2_weights, axis=0) - if self.quant_config.name() == "wint8": + if algo == "wint8": max_bound = 127 - elif self.quant_config.name() == "wint4": + elif algo == "wint4": max_bound = 7 for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]): @@ -111,15 +123,13 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): moe_intermediate_size = layer.moe_intermediate_size hidden_size = layer.hidden_size - gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight) - scores = paddle.nn.functional.softmax(gate_out, axis=-1) - - topk_weights, topk_ids = paddle.topk(scores, - k=top_k, - axis=-1, - sorted=False) - topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True) - + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + top_k, + True, # apply_norm_weight, + False, + ) intermediate_cache1 = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, @@ -139,14 +149,12 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, } - from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess - - from .triton_moe_kernels import fused_moe_kernel_paddle - sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess( + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( topk_ids, num_local_experts, config["BLOCK_SIZE_M"]) - max_num_tokens_padded = sorted_token_ids.shape[0] - grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * - ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), ) + max_possible_num_post_padded = sorted_token_ids.shape[0] + grid = ( + ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) * + ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), ) fused_moe_kernel_paddle[grid]( x, @@ -158,10 +166,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): sorted_token_ids, expert_ids, num_tokens_post_padded, - moe_intermediate_size * 2, - hidden_size, - max_num_tokens_padded, + max_possible_num_post_padded, token_num * top_k, + N=moe_intermediate_size * 2, + K=hidden_size, stride_am=x.strides[0], stride_ak=x.strides[1], stride_be=layer.moe_ffn1_weight.strides[0], @@ -193,8 +201,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): intermediate_cache2 = paddle.incubate.nn.functional.swiglu( intermediate_cache1) - grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * - ceil_div(hidden_size, config["BLOCK_SIZE_N"]), ) + grid = ( + ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) * + ceil_div(hidden_size, config["BLOCK_SIZE_N"]), ) fused_moe_kernel_paddle[grid]( intermediate_cache2, layer.moe_ffn2_weight, @@ -205,10 +214,10 @@ 
class TritonWeightOnlyMoEMethod(QuantMethodBase): sorted_token_ids, expert_ids, num_tokens_post_padded, - hidden_size, - moe_intermediate_size, - max_num_tokens_padded, + max_possible_num_post_padded, token_num * top_k, + N=hidden_size, + K=moe_intermediate_size, stride_am=intermediate_cache2.strides[0], stride_ak=intermediate_cache2.strides[1], stride_be=layer.moe_ffn2_weight.strides[0], @@ -324,7 +333,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): moe_intermediate_size = layer.moe_intermediate_size hidden_size = layer.hidden_size - gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight) scores = paddle.nn.functional.softmax(gate_out, axis=-1) topk_weights, topk_ids = paddle.topk(scores, @@ -352,13 +360,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, } - from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess - sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess( + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( topk_ids, num_local_experts, config["BLOCK_SIZE_M"]) - max_num_tokens_padded = sorted_token_ids.shape[0] - grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * - ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), ) + max_possible_num_post_padded = sorted_token_ids.shape[0] + grid = ( + ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) * + ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), ) adamard_matrix = create_hadamard_matrix_map[hidden_size] x = paddle.matmul(x.cast("float32"), adamard_matrix) @@ -371,8 +379,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): permute_x = permute_x / quant_activation_scale permute_x = permute_x.astype("float8_e4m3fn") - from .triton_moe_kernels import fused_moe_kernel_paddle - fused_moe_kernel_paddle[grid]( permute_x, layer.moe_ffn1_weight.view(paddle.float8_e4m3fn), @@ -383,10 +389,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): sorted_token_ids, expert_ids, num_tokens_post_padded, - moe_intermediate_size * 2, - hidden_size, - max_num_tokens_padded, + max_possible_num_post_padded, token_num * top_k, + N=moe_intermediate_size * 2, + K=hidden_size, stride_am=x.strides[0], stride_ak=x.strides[1], stride_be=layer.moe_ffn1_weight.strides[0], @@ -426,8 +432,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): intermediate_cache2 = intermediate_cache2 / quant_activation_scale intermediate_cache2 = intermediate_cache2.astype("float8_e4m3fn") - grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * - ceil_div(hidden_size, config["BLOCK_SIZE_N"]), ) + grid = ( + ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) * + ceil_div(hidden_size, config["BLOCK_SIZE_N"]), ) fused_moe_kernel_paddle[grid]( intermediate_cache2, @@ -439,10 +446,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): sorted_token_ids, expert_ids, num_tokens_post_padded, - hidden_size, - moe_intermediate_size, - max_num_tokens_padded, + max_possible_num_post_padded, token_num * top_k, + N=hidden_size, + K=moe_intermediate_size, stride_am=intermediate_cache2.strides[0], stride_ak=intermediate_cache2.strides[1], stride_be=layer.moe_ffn2_weight.strides[0], diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py index ea7d722c7..99e156d61 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py @@ -224,6 +224,7 @@ class 
TritonWint2FusedMoeMethod(Wint2MoeMethod): ) from fastdeploy.model_executor.ops.gpu import moe_expert_reduce + fused_moe_out = moe_expert_reduce( ffn_out, topk_weights, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 6bef4fc6a..a14b4e2cc 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -30,10 +30,15 @@ class FusedMoE(nn.Layer): def __init__( self, fd_config, + reduce_results: bool = True, moe_intermediate_size: int = -1, num_experts: int = -1, expert_id_offset: int = 0, top_k: int = -1, + topk_method: str = "", + topk_group: int = -1, + n_group: int = -1, + routed_scaling_factor: float = 1.0, layer_idx: int = -1, moe_tag: str = "", weight_key_map: dict = {}, @@ -49,6 +54,7 @@ class FusedMoE(nn.Layer): self.fd_config = fd_config self.layer_idx = layer_idx + self.reduce_results = reduce_results self.tp_size = fd_config.parallel_config.tensor_parallel_degree self.ep_size = fd_config.parallel_config.expert_parallel_degree @@ -60,28 +66,32 @@ class FusedMoE(nn.Layer): self.hidden_size = fd_config.model_config.hidden_size self.moe_config = fd_config.moe_config - self.num_experts = num_experts self.num_local_experts = self.num_experts // self.ep_size self.moe_intermediate_size = moe_intermediate_size // self.tp_size self.top_k = top_k - self.hidden_size = self.hidden_size - self.moe_intermediate_size = moe_intermediate_size // self.tp_size self.weight_key_map = weight_key_map self.use_method = envs.FD_MOE_BACKEND.lower() self.gate_correction_bias = None self.moe_tag = moe_tag - if self.ep_size > 1: expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts self.expert_id_offset = expert_id_offset - if fd_config.quant_config: - self.quant_method = fd_config.quant_config.get_quant_method(self) + # used for deepseek_v3 + self.topk_method = topk_method + self.topk_group = topk_group + self.n_group = n_group + self.routed_scaling_factor = routed_scaling_factor + + moe_quant_config = fd_config.quant_config + if moe_quant_config: + self.quant_method = moe_quant_config.get_quant_method(self) + self.moe_quant_type = moe_quant_config.name() else: # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future from .fused_moe_cutlass_backend import CutlassMoEMethod @@ -90,12 +100,78 @@ class FusedMoE(nn.Layer): if self.ep_size > 1: self.quant_method.init_ep(self) + if fd_config.load_config.dynamic_load_weight: + # It's for RL to build model + self.init_moe_weights() + logger.info( f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset+self.num_local_experts}), \ {top_k=}, hidden_size={self.hidden_size}, {moe_intermediate_size=}, \ , ep_size={self.ep_size}, \ tp_size={self.tp_size}.") + def init_moe_weights(self): + """ + Initialize the weight shapes and parameters for the MoE layer. + Combines weight shape initialization and parameter creation into a single function. 
+ """ + # Initialize weight shapes + self._dtype = self._helper.get_default_dtype() + self.weight_dtype = self._dtype + gate_weight_shape = [self.hidden_size, self.num_experts] + gate_correction_bias_shape = [1, self.num_experts] + + self.gate_weight = self.create_parameter( + shape=gate_weight_shape, + dtype="float32", + ) + if self.moe_config.moe_use_aux_free: + self.gate_correction_bias = self.create_parameter( + shape=gate_correction_bias_shape, + dtype="float32", + ) + ffn1_output_dim = self.moe_intermediate_size * 2 + if self.moe_quant_type in ["fp8", "wint8"]: + ffn1_weight_shape = [self.num_local_experts, ffn1_output_dim, self.hidden_size] + ffn2_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size] + else: + ffn1_weight_shape = [self.num_local_experts, self.hidden_size, ffn1_output_dim] + ffn2_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size] + + # Create parameters + if self.moe_quant_type == "fp8": + #(TODO:gaoziyuan) + pass + else: + self.weight_dtype = "int8" + self.init_weight_only_scale() + + # FFN1 parameters + self.moe_ffn1_weight = self.create_parameter( + shape=ffn1_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + # FFN2 parameters + self.moe_ffn2_weight = self.create_parameter( + shape=ffn2_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + def init_weight_only_scale(self): + """ + Initialize the weight scale. + """ + self.moe_ffn1_weight_scale = self.create_parameter( + shape=[self.num_local_experts, self.moe_intermediate_size * 2], + dtype=self._dtype, + ) + self.moe_ffn2_weight_scale = self.create_parameter( + shape=[self.num_local_experts, self.hidden_size], + dtype=self._dtype, + ) + def load_experts_weight(self, state_dict: dict, ffn1_expert_weight_key: str, ffn2_expert_weight_key: str): diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py index 4a0c33f82..ff289524f 100644 --- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py +++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py @@ -16,9 +16,10 @@ import triton import triton.language as tl +from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import paddle_use_triton_v2 -@triton.jit +@paddle_use_triton_v2() def fused_moe_kernel_paddle( a_ptr, b_ptr, @@ -31,22 +32,22 @@ def fused_moe_kernel_paddle( num_tokens_post_padded_ptr, # Matrix dimensions - N, - K, - num_tokens_post_padded, + max_possible_num_post_padded, num_valid_tokens, - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, + N: tl.constexpr, + K: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_be: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_asm: tl.constexpr, + stride_ask: tl.constexpr, + stride_bse: tl.constexpr, + stride_bsk: tl.constexpr, + stride_bsn: tl.constexpr, # Block size for block-wise fp8 quantization group_n: tl.constexpr, group_k: tl.constexpr, @@ -87,7 +88,7 @@ def fused_moe_kernel_paddle( multiplication across different blocks processed by the same expert. 
""" pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M) + num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 557b01bd8..6d25df345 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -14,10 +14,15 @@ # limitations under the License. """ +from typing import Callable, Dict, Optional + +import numpy as np import paddle from paddle import nn from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm +from fastdeploy.config import FDConfig + from .utils import get_tensor @@ -28,16 +33,16 @@ class RMSNorm(nn.Layer): def __init__( self, - fd_config, - hidden_size, - eps=1e-5, - prefix="", - linear_bias=None, - quant_scale=None, - begin_norm_axis=1, - ): + fd_config: FDConfig, + hidden_size: int, + eps: float = 1e-5, + prefix: str = "", + linear_bias: paddle.Tensor = None, + quant_scale: float = None, + begin_norm_axis: int = 1, + ) -> None: """ - Initializes the normalization layer. + Initializes the RMSNormalization layer. Args: fd_config (FDConfig): Arguments related to inference, containing @@ -45,33 +50,33 @@ class RMSNorm(nn.Layer): num_attention_heads, and ffn_hidden_size. hidden_size (int) : size of hidden state. eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5. - weight_key (str): Key name of weight in the pdparams state dict. Defaults to None, means no weight. - bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias. - linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None. + prefix(str,optional):The name of current layer. Defaults to "". + linear_bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None. + quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization. + begin_norm_axis (int, optional): The axis along which to perform normalization. Defaults to 1. Raises: NotImplementedError: If the specified norm_type is not supported. 
""" super().__init__() self.fd_config = fd_config - self.prefix = prefix - self.hidden_size = hidden_size + self.prefix: str = prefix + self.hidden_size: int = hidden_size if len(prefix) == 0: - self.weight_key = None + self.weight_key: Optional[str] = None else: - self.weight_key = f"{prefix}.weight" - self.with_weight = self.weight_key is not None - self.eps = eps - self.norm_func = fused_rms_norm - self.linear_bias = linear_bias - self.quant_scale = quant_scale - self._dtype = self._helper.get_default_dtype() - self._norm_weight_dtype = self._dtype - self.begin_norm_axis = begin_norm_axis - self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0 - self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0 - self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0 - self.begin_norm_axis = begin_norm_axis + self.weight_key: Optional[str] = f"{prefix}.weight" + self.with_weight: bool = self.weight_key is not None + self.eps: float = eps + self.norm_func: Callable = fused_rms_norm + self.linear_bias: Optional[paddle.Tensor] = linear_bias + self.quant_scale: Optional[float] = quant_scale + self._dtype: str = self._helper.get_default_dtype() + self._norm_weight_dtype: str = self._dtype + self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0 + self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0 + self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0 + self.begin_norm_axis: int = begin_norm_axis self.init_weight() @@ -88,7 +93,8 @@ class RMSNorm(nn.Layer): dtype=self._norm_weight_dtype, ) - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: Dict[str, + paddle.Tensor | np.ndarray]): """ Load the checkpoint state dictionary into the layer. @@ -102,7 +108,10 @@ class RMSNorm(nn.Layer): self._norm_weight_dtype) self.ln_weight.set_value(weight_tensor) - def forward(self, x, residual_input=None): + def forward( + self, + x, + residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor: """ Defines the forward computation of the layer. @@ -140,18 +149,18 @@ class RMSNorm(nn.Layer): class LayerNorm(nn.Layer): """ - Normalization layer. + Initializes the LayerNormalization layer """ def __init__( self, - fd_config, - hidden_size, - eps=1e-5, + fd_config: FDConfig, + hidden_size: int, + eps: float = 1e-5, prefix="", - linear_bias=None, - quant_scale=None, - with_bias=False, + linear_bias: paddle.Tensor = None, + quant_scale: float = None, + with_bias: bool = False, ): """ Initializes the normalization layer. @@ -160,35 +169,37 @@ class LayerNorm(nn.Layer): fd_config (FDConfig): Arguments related to inference, containing attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim, num_attention_heads, and ffn_hidden_size. - prefix (str): Unique name of the layer, used for naming internal attributes, - you can give it any name you like. hidden_size (int) : size of hidden state. eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5. + prefix (str): Unique name of the layer, used for naming internal attributes, + you can give it any name you like. linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None. + quant_scale(float,optional):Quantization scale, used in quantization scenarios. 
Defaults to -1, indicating no quantization. + with_bias (bool):Whether to include bias or not. Defaults to False. Raises: NotImplementedError: If the specified norm_type is not supported. """ super().__init__() self.fd_config = fd_config - self.prefix = prefix - self.hidden_size = hidden_size + self.prefix: str = prefix + self.hidden_size: int = hidden_size if len(prefix) == 0: - self.weight_key = None + self.weight_key: Optional[str] = None else: - self.weight_key = f"{prefix}.weight" - self.with_weight = self.weight_key is not None - self.bias_key = f"{prefix}.bias" - self.with_bias = with_bias - self.eps = eps + self.weight_key: Optional[str] = f"{prefix}.weight" + self.with_weight: bool = self.weight_key is not None + self.bias_key: str = f"{prefix}.bias" + self.with_bias: bool = with_bias + self.eps: float = eps + self.quant_scale: float = quant_scale + self.norm_func: Callable = fused_layer_norm + self.linear_bias: Optional[paddle.Tensor] = linear_bias + self._dtype: str = self._helper.get_default_dtype() + self._norm_weight_dtype: str = "float32" - self.norm_func = fused_layer_norm - self.linear_bias = linear_bias - self._dtype = self._helper.get_default_dtype() - self._norm_weight_dtype = "float32" - - self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0 - self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0 - self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0 + self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0 + self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0 + self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0 self.init_weight() @@ -212,7 +223,8 @@ class LayerNorm(nn.Layer): dtype=self._norm_weight_dtype, ) - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: Dict[str, + paddle.Tensor | np.ndarray]): """ Load the checkpoint state dictionary into the layer. @@ -233,7 +245,10 @@ class LayerNorm(nn.Layer): self._norm_weight_dtype) self.ln_bias.set_value(bias_tensor) - def forward(self, x, residual_input=None): + def forward( + self, + x, + residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor: """ Defines the forward computation of the layer. @@ -259,7 +274,7 @@ class LayerNorm(nn.Layer): begin_norm_axis=1, bias=self.linear_bias, residual=residual_input, - quant_scale=-1, + quant_scale=-1 if self.quant_scale is None else self.quant_scale, quant_round_type=self.quant_round_type, quant_max_bound=self.quant_max_bound, quant_min_bound=self.quant_min_bound, diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index f0bc3fc11..9e890853b 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -132,18 +132,14 @@ class WeightOnlyLinearMethod(QuantMethodBase): self.quant_config = quant_config def create_weights(self, layer): + + # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. 
+ linear_weight_scale_shape = [layer.linear_weight_shape[1]] + layer.linear_weight_shape.reverse() if self.quant_config.name() == "wint4": layer.linear_weight_shape[0] //= 2 layer.weight_dtype = "int8" - linear_weight_scale_shape = [layer.embed_dim] - if hasattr(layer, "linear_weight_shape"): - if isinstance(layer.linear_weight_shape, list): - layer_weight_shape = layer.linear_weight_shape - linear_weight_scale_shape = layer_weight_shape[:1] - if self.quant_config.name() == "wint4": - linear_weight_scale_shape[0] *= 2 - layer.linear_weight_scale = layer.create_parameter( shape=linear_weight_scale_shape, dtype=layer._dtype, @@ -195,6 +191,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod): weight_scale.astype(paddle.get_default_dtype())) def process_loaded_weights(self, layer, weight) -> None: + quanted_weight_tensor, weight_scale_tensor = weight_quantize( weight, algo=self.quant_config.algo, diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py index de3ded87e..0521c4166 100644 --- a/fastdeploy/model_executor/layers/rotary_embedding.py +++ b/fastdeploy/model_executor/layers/rotary_embedding.py @@ -14,13 +14,18 @@ # limitations under the License. """ -from typing import Optional +import math +from typing import Optional, Tuple import paddle +import paddle.nn as nn from fastdeploy.config import ModelConfig from fastdeploy.platforms import current_platform +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding + from .utils import CpuGuard @@ -99,20 +104,164 @@ class QwenRotaryEmbedding: return rot_emb +def yarn_get_mscale(scale=1, mscale=1): + """ + """ + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_find_correction_dim(num_rotations, + dim, + base=10000, + max_position_embeddings=2048): + """ + """ + return (dim * math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * + math.log(base)) + + +def yarn_find_correction_range(low_rot, + high_rot, + dim, + base=10000, + max_position_embeddings=2048): + """ + """ + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_linear_ramp_mask(min, max, dim): + """ + """ + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max - + min) + ramp_func = paddle.clip(linear_func, 0, 1) + return ramp_func + + +class DeepseekScalingRotaryEmbedding(nn.Layer): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + + Args: + rotary_dim(int): Dimension of rotary embeddings (head dimension) + max_position_embeddings(int): Original training context length + base(float): Base value used to compute the inverse frequencies. 
+ scaling_factor(float): Context extension scaling ratio (target_len / original_len) + extrapolation_factor(float): Weight for extrapolated frequencies (default=1) + attn_factor(float): Attention magnitude scaling factor (default=1) + beta_fast(int): High-frequency correction cutoff (default=32) + beta_slow(int): Low-frequency correction cutoff (default=1) + mscale(float): Primary magnitude scaling factor (default=1) + mscale_all_dim(float): Alternate magnitude scaling factor (default=0) + + """ + + def __init__( + self, + rotary_dim: int, + max_position_embeddings: int, + base: int, + scaling_factor: float, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + super().__init__() + self._dtype = paddle.get_default_dtype() + + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + + cache = self._compute_cos_sin_cache() + + self.cos_sin_cache: paddle.Tensor + self.register_buffer("cos_sin_cache", cache, persistable=True) + + def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor: + pos_freqs = self.base**( + paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) / + self.rotary_dim) + + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> paddle.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = paddle.arange(self.max_position_embeddings * self.scaling_factor, + dtype=paddle.float32) + freqs = paddle.einsum("i,j->ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = paddle.concat((cos, sin), axis=-1) + return cache.cast(self._dtype) + + def forward( + self, + position_ids: paddle.Tensor, + query: paddle.Tensor, + key: paddle.Tensor, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """ + """ + # In-place operations that update the query and key tensors. 
+ fused_rotary_position_encoding(query, key, position_ids, + self.cos_sin_cache, self.rotary_dim, + False) + + return query, key + + def get_rope_impl( rotary_dim: int, base: 10000.0, - position_ids, + position_ids: paddle.Tensor, model_config: Optional[ModelConfig] = None, partial_rotary_factor=1, -): +) -> paddle.Tensor: """ The real implementation of get_rope """ architecture = model_config.architectures[0] - if model_config is not None and model_config is None or architecture.startswith( - "Qwen"): + if model_config is None or architecture.startswith("Qwen"): rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor) rotary_emb = rotary_emb_layer(position_ids) @@ -126,10 +275,10 @@ def get_rope_impl( def get_rope_xpu( rotary_dim: int, base: 10000.0, - position_ids, - model_config: ModelConfig, + position_ids: paddle.Tensor, + model_config: Optional[ModelConfig] = None, partial_rotary_factor=1, -): +) -> paddle.Tensor: """ In XPU, cos and sin compute must be done on cpu """ @@ -143,12 +292,27 @@ def get_rope_xpu( def get_rope( rotary_dim: int, base: 10000.0, - position_ids, - model_config: ModelConfig, - partial_rotary_factor=1, -): + position_ids: paddle.Tensor, + model_config: Optional[ModelConfig] = None, + partial_rotary_factor: int = 1, +) -> paddle.Tensor: """ - The warpper of get_rope + Pre-calculate rotary position embedding for position_ids. + + Args: + rotary_dim (int): + Dimension of rotary embeddings (head dimension) + base (float, optional): + Base value used to compute the inverse frequencies. + Default: 10000.0. + position_ids (paddle.Tensor): + Tensor containing position indices of input tokens. + model_config (Optional[ModelConfig]): + Model configuration object containing architecture information. + If provided, determines RoPE implementation based on model architecture. + partial_rotary_factor (int, optional): + Factor controlling partial rotary application. + Default: 1 (apply to all dimensions). """ if current_platform.is_xpu(): return get_rope_xpu(rotary_dim, base, position_ids, model_config, @@ -255,7 +419,24 @@ def get_rope_3d( paritial_rotary_factor: 1, max_position: 131072, freq_allocation: 2, -): +) -> paddle.Tensor: + """ + Pre-calculate rotary position embedding for position_ids. + + Args: + rotary_dim (int): + Dimension of rotary embeddings (head dimension) + base (float, optional): + Base value used to compute the inverse frequencies. + Default: 10000.0. + position_ids (paddle.Tensor): + Tensor containing position indices of input tokens. + partial_rotary_factor (int, optional): + Factor controlling partial rotary application. + Default: 1 (apply to all dimensions). + max_position: Maximum position index to precompute. 
+ freq_allocation: Number of rotary dimensions allocated to temporal axis + """ rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base, paritial_rotary_factor, max_position, diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index 3cf9910d1..255c17e7a 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ -377,4 +377,4 @@ def create_and_set_parameter(layer: nn.Layer, name: str, dtype=tensor.dtype, default_initializer=paddle.nn.initializer.Constant(0), )) - getattr(layer, name).set_value(tensor) + getattr(layer, name).set_value(tensor) \ No newline at end of file diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py new file mode 100644 index 000000000..a5c84a365 --- /dev/null +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -0,0 +1,289 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import json +import os + +import paddle +import paddle.distributed as dist +from fastsafetensors import SafeTensorsFileLoader, SingleGroup +from paddleformers.transformers import PretrainedModel +from paddleformers.transformers.model_utils import load_tp_checkpoint +from safetensors import safe_open +from tqdm import tqdm + +from fastdeploy.config import FDConfig, ModelConfig +from fastdeploy.model_executor.models.tp_utils import \ + check_tensor_parallel_prerequisites +from fastdeploy.platforms import current_platform + + +def load_ep_checkpoint(model_path: str, + config: ModelConfig, + return_numpy: bool = False): + """ + load ep checkpoint + """ + with open(os.path.join(model_path, "model.safetensors.index.json"), + "r") as f: + weight_list = json.load(f)["weight_map"] + filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k} + num_local_ffn_keys = [] + + for i in range(config.moe_layer_start_index, config.num_layers): + for j in range( + config.num_experts_start_offset, + config.num_experts_start_offset + config.num_experts_per_rank, + ): + ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight" + ffn2_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight") + + ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight" + ffn2_quant_key = ( + f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight") + + ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale" + ffn2_scale_key = ( + f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale") + num_local_ffn_keys.append(ffn1_key) + num_local_ffn_keys.append(ffn2_key) + num_local_ffn_keys.append(ffn1_quant_key) + num_local_ffn_keys.append(ffn2_quant_key) + num_local_ffn_keys.append(ffn1_scale_key) + num_local_ffn_keys.append(ffn2_scale_key) + + for k in num_local_ffn_keys: + if k in weight_list: + filtered_map[k] = weight_list[k] + + state_dict = {} + # Get all safetensor file paths that need to be opened + safetensor_paths = 
set(filtered_map.values()) + + # Open each safetensor file sequentially with progress bar + for safetensor_path in tqdm(safetensor_paths, + desc="Loading safetensor files", + unit="file"): + with safe_open(os.path.join(model_path, safetensor_path), + framework="np", + device="cpu") as f: + # Check if this file contains keys from filtered_map + for k in filtered_map: + if filtered_map[k] == safetensor_path and k in f.keys(): + weight = f.get_tensor(k) + if not return_numpy: + weight = paddle.Tensor(weight, zero_copy=True) + weight = weight._copy_to( + paddle.framework._current_expected_place(), False) + state_dict[k] = weight + return state_dict + + +def safetensors_weights_iterator(safe_tensor_list: list[str], ): + """ + safetensors_weights_iterator + """ + for st_file in tqdm( + safe_tensor_list, + desc="Loading safetensors checkpoint shards", + ): + with safe_open(st_file, framework="np") as f: + for name in f.keys(): + param = f.get_tensor(name) + yield name, param + + +def fastsafetensors_weights_iterator(safetensor_list: list[str], ): + """ + Return an iterator over tensors on GPU from a given safetensor_list. + """ + world_size = dist.get_world_size() + if world_size > 1: + pg = dist.get_group() + device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu" + else: + pg = SingleGroup() + device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda( + ) else "cpu" + + safetensor_files_sub_lists = [ + safetensor_list[i:i + world_size] + for i in range(0, len(safetensor_list), world_size) + ] + + for st_file in tqdm( + safetensor_files_sub_lists, + desc="Loading fastsafetensors checkpoint shards", + ): + loader = SafeTensorsFileLoader(pg, + device, + nogds=True, + debug_log=False, + framework="paddle") + rank_file_map = {i: [f] for i, f in enumerate(st_file)} + loader.add_filenames(rank_file_map) + try: + fb = loader.copy_files_to_device() + try: + keys = list(fb.key_to_rank_lidx.keys()) + for k in keys: + t = fb.get_tensor(k) + yield k, t + finally: + fb.close() + finally: + loader.close() + + +def load_pre_sharded_checkpoint(model_path: str, + local_rank: int, + use_fastsafetensor: bool = False): + """ + load_pre_sharded_checkpoint + """ + state_dict = {} + _, safetensor_files = get_all_safetensors( + os.path.join(model_path, f"rank{local_rank}")) + weights_iterator = safetensors_weights_iterator(safetensor_files) + for name, weight in weights_iterator: + state_dict[name] = weight + return state_dict + + +def get_all_safetensors(model_path: str): + """ + get_all_safetensors + """ + safe_model_path = os.path.join(model_path, "model.safetensors") + if os.path.exists(safe_model_path): + safetensor_list = [safe_model_path] + with safe_open(safe_model_path, framework="np", device="cpu") as f: + key_name_list = f.keys() + return key_name_list, safetensor_list + else: + with open(os.path.join(model_path, "model.safetensors.index.json"), + "r") as f: + weight_map = json.load(f)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add( + os.path.join(model_path, weight_map[weight_name])) + key_name_list = list(set(weight_map.keys())) + safetensor_list = list(weight_files_in_index) + safetensor_list.sort() + return key_name_list, safetensor_list + + +def load_tp_checkpoint_v1( + model_path: str, + cls: PretrainedModel, + fd_config: FDConfig, + use_fastsafetensor: bool = True, +): + """ + load_tp_checkpoint_v1 + """ + + safetensor_keys, safetensor_files = get_all_safetensors(model_path) + + if use_fastsafetensor: + weights_iterator = 
fastsafetensors_weights_iterator(safetensor_files) + else: + weights_iterator = safetensors_weights_iterator(safetensor_files) + + tensor_parallel_filtered_map = {} + check_tensor_parallel_prerequisites( + fd_config, + cls, + tensor_parallel_filtered_map, + safetensor_keys, + ) + need_tp = True if tensor_parallel_filtered_map else False + state_dict = {} + for key, weight in weights_iterator: + paddle.device.cuda.synchronize() + if need_tp and key in tensor_parallel_filtered_map: + action = tensor_parallel_filtered_map.pop(key) + tensor = action(weight).clone() + else: + tensor = weight.clone() + state_dict[key] = tensor + weight.value().get_tensor()._clear() + return state_dict + + +def deal_state_dict(state_dict): + """deal_state_dict""" + device = paddle.CUDAPinnedPlace() + for name, src in state_dict.items(): + if src._is_initialized() and not isinstance(src.place, + paddle.CUDAPinnedPlace): + dst = src._copy_to(device, True) + dst_tensor = dst.value().get_tensor() + src_tensor = src.value().get_tensor() + src_tensor._clear() + src_tensor._share_data_with(dst_tensor) + + +def load_composite_checkpoint( + model_path: str, + cls: PretrainedModel, + fd_config: FDConfig, + return_numpy=True, +): + """ + # This method supports loading model weights under three parallelism strategies: + # 1. Expert Parallel (EP) + # 2. Tensor Parallel (TP) + # 3. Pre-sharded (pre-split) + """ + if fd_config.parallel_config.use_ep: + state_dict = load_ep_checkpoint(model_path, + fd_config.model_config, + return_numpy=True) + else: + rank_dirs = [ + f for f in os.listdir(model_path) if f.startswith("rank") + and os.path.isdir(os.path.join(model_path, f)) + ] + if len(rank_dirs) > 1: + if fd_config.parallel_config.tensor_parallel_degree != len( + rank_dirs): + raise ValueError( + f"Your model only supports loading with tp{len(rank_dirs)}" + ) + state_dict = load_pre_sharded_checkpoint( + model_path, + fd_config.parallel_config.tensor_parallel_rank, + use_fastsafetensor=False, + ) + else: + if fd_config.load_config.use_fastsafetensor and ( + current_platform.available() + and current_platform.is_cuda()): + state_dict = load_tp_checkpoint_v1(model_path, + cls, + fd_config, + use_fastsafetensor=True) + deal_state_dict(state_dict) + else: + state_dict = load_tp_checkpoint(model_path, + cls, + fd_config.model_config, + return_numpy=return_numpy) + if not state_dict: + raise ValueError("weight not found in state_dict !") + return state_dict diff --git a/fastdeploy/model_executor/model_loader.py b/fastdeploy/model_executor/model_loader.py index 604ba73e9..2010c2021 100644 --- a/fastdeploy/model_executor/model_loader.py +++ b/fastdeploy/model_executor/model_loader.py @@ -20,6 +20,10 @@ import paddle from paddle import nn from fastdeploy.config import FDConfig, LoadConfig, ModelConfig +from fastdeploy.model_executor.load_weight_utils import \ + load_composite_checkpoint +from fastdeploy.model_executor.models.deepseek_v3 import \ + DeepSeekV3PretrainedModel from fastdeploy.model_executor.models.ernie4_5_moe import \ Ernie4_5_PretrainedModel from fastdeploy.model_executor.models.ernie4_5_mtp import \ @@ -28,7 +32,7 @@ from fastdeploy.model_executor.models.model_base import ModelRegistry from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel -from fastdeploy.model_executor.models.utils import load_checkpoint +from fastdeploy.platforms import current_platform 
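For readers skimming the diff: the new load_composite_checkpoint picks one of three loading paths at runtime. The sketch below is an illustrative summary, not code from the patch; it only mirrors the branching shown above and assumes the FDConfig fields used there (parallel_config.use_ep, parallel_config.tensor_parallel_degree) plus a hypothetical model_path layout.

import os

def describe_loading_strategy(fd_config, model_path: str) -> str:
    # Sketch of the dispatch in load_composite_checkpoint (illustrative only).
    if fd_config.parallel_config.use_ep:
        return "EP: load_ep_checkpoint keeps only this rank's expert weights"
    rank_dirs = [d for d in os.listdir(model_path)
                 if d.startswith("rank") and os.path.isdir(os.path.join(model_path, d))]
    if len(rank_dirs) > 1:
        return "pre-sharded: load_pre_sharded_checkpoint reads the per-rank shards as-is"
    return "TP: load_tp_checkpoint / load_tp_checkpoint_v1 split full weights while loading"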
MODEL_CLASSES = { "Ernie4_5_MoeForCausalLM": Ernie4_5_PretrainedModel, @@ -36,7 +40,8 @@ MODEL_CLASSES = { "Qwen2ForCausalLM": Qwen2PretrainedModel, "Qwen3ForCausalLM": Qwen3PretrainedModel, "Qwen3MoeForCausalLM": Qwen3MoePretrainedModel, - "Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel + "Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel, + "DeepseekV3ForCausalLM": DeepSeekV3PretrainedModel, } @@ -73,23 +78,38 @@ class DefaultModelLoader(BaseModelLoader): def download_model(self, model_config: ModelConfig) -> None: pass + def clean_memory_fragments(self, state_dict: dict) -> None: + """clean_memory_fragments""" + if current_platform.is_cuda(): + if state_dict: + for k, v in state_dict.items(): + if isinstance(v, paddle.Tensor): + v.value().get_tensor()._clear() + paddle.device.cuda.empty_cache() + paddle.device.cuda.synchronize() + def load_model(self, fd_config: FDConfig) -> nn.Layer: context = paddle.LazyGuard() architectures = fd_config.model_config.architectures[0] - # TODO(gongshaotian): Now, only support safetensor - model_class = MODEL_CLASSES[architectures] - state_dict = load_checkpoint( - fd_config.parallel_config.model_name_or_path, - model_class, - fd_config.model_config, - return_numpy=True) + with context: model_cls = ModelRegistry.get_class(architectures) model = model_cls(fd_config) model.eval() - model.set_state_dict(state_dict) + # RL model not need set_state_dict + if fd_config.load_config.dynamic_load_weight: + return model + + state_dict = load_composite_checkpoint( + fd_config.parallel_config.model_name_or_path, + model_class, + fd_config, + return_numpy=True, + ) + model.set_state_dict(state_dict) + self.clean_memory_fragments(state_dict) return model diff --git a/fastdeploy/model_executor/models/__init__.py b/fastdeploy/model_executor/models/__init__.py index e26e82474..2ace118ef 100644 --- a/fastdeploy/model_executor/models/__init__.py +++ b/fastdeploy/model_executor/models/__init__.py @@ -20,15 +20,6 @@ from pathlib import Path from .model_base import ModelForCasualLM, ModelRegistry -inference_runner_supported_models = [ - "Ernie4_5_MoeForCausalLM", - "Ernie4_5_MTPForCausalLM", - "Qwen2ForCausalLM", - "Qwen3MoeForCausalLM", - "Ernie4_5_ForCausalLM", - "Qwen3ForCausalLM", -] - def _find_py_files(root_dir): root_path = Path(root_dir) @@ -44,22 +35,23 @@ def _find_py_files(root_dir): return py_files -def auto_models_registry(): +def auto_models_registry(dir_path, + register_path="fastdeploy.model_executor.models", + suffix=""): """ auto registry all models in this folder """ - for module_file in _find_py_files(os.path.dirname(__file__)): + for module_file in _find_py_files(dir_path): try: - module = importlib.import_module( - f'fastdeploy.model_executor.models.{module_file}') + module = importlib.import_module(f'{register_path}.{module_file}') for attr_name in dir(module): attr = getattr(module, attr_name) if inspect.isclass(attr) and issubclass( attr, ModelForCasualLM) and attr is not ModelForCasualLM: - ModelRegistry.register(attr) + ModelRegistry.register(attr, suffix=suffix) except ImportError: raise ImportError(f"{module_file=} import error") -auto_models_registry() +auto_models_registry(os.path.dirname(__file__)) diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py new file mode 100644 index 000000000..8286e77d8 --- /dev/null +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -0,0 +1,763 @@ +""" +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +import math +from functools import partial + +import paddle +from paddle import nn +from paddleformers.transformers import PretrainedModel +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.distributed.communication_op import \ + tensor_model_parallel_all_reduce +from fastdeploy.model_executor.layers.activation import SiluAndMul +from fastdeploy.model_executor.layers.attention import Attention +from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding +from fastdeploy.model_executor.layers.linear import ( + ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear, + ReplicatedLinear, RowParallelLinear) +from fastdeploy.model_executor.layers.lm_head import ParallelLMHead +from fastdeploy.model_executor.layers.moe.moe import FusedMoE +from fastdeploy.model_executor.layers.normalization import RMSNorm +from fastdeploy.model_executor.layers.rotary_embedding import \ + DeepseekScalingRotaryEmbedding +from fastdeploy.model_executor.models.model_base import ModelForCasualLM +from fastdeploy.platforms import current_platform +from fastdeploy.worker.forward_meta import ForwardMeta + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import \ + get_position_ids_and_mask_encoder_batch + + +class DeepSeekV3MLP(nn.Layer): + """ + DeepSeekV3MLP, for Dense FFN and Shared Experts Layer. + """ + + def __init__( + self, + fd_config: FDConfig, + intermediate_size: int, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + + self.gate_up_proj = MergedColumnParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.up_gate_proj", + input_size=fd_config.model_config.hidden_size, + output_size=intermediate_size * 2, + with_bias=False, + activation=fd_config.model_config.hidden_act, + use_fast_ffn=True, + ) + + self.down_proj = RowParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.down_proj", + input_size=intermediate_size, + output_size=fd_config.model_config.hidden_size, + with_bias=False, + reduce_results=reduce_results, + ) + + self.act_fn = SiluAndMul( + fd_config=fd_config, + bias=None, + act_method=fd_config.model_config.hidden_act, + ) + + def load_state_dict(self, state_dict): + """ + """ + self.gate_up_proj.load_state_dict(state_dict) + self.down_proj.load_state_dict(state_dict) + + def forward(self, x): + """ + """ + gate_up_out = self.gate_up_proj(x) + act_out = self.act_fn(gate_up_out) + down_out = self.down_proj(act_out) + return down_out + + +class DeepSeekV3MoE(nn.Layer): + """ + DeepSeekV3MoE, for MoE Layer. 
+ """ + + def __init__(self, fd_config: FDConfig, layer_id: int, + prefix: str) -> None: + super().__init__() + + self.tp_size = fd_config.parallel_config.tensor_parallel_degree + + weight_key_map = { + "gate_weight_key": f"{prefix}.gate.weight", + "gate_correction_bias_key": + f"{prefix}.gate.e_score_correction_bias", + "ffn1_expert_weight_key": + f"{prefix}.experts.{{}}.up_gate_proj.weight", + "ffn2_expert_weight_key": + f"{prefix}.experts.{{}}.down_proj.weight", + } + + self.fused_moe = FusedMoE( + fd_config=fd_config, + reduce_results=False, + moe_intermediate_size=fd_config.model_config.deepseekv3. + moe_intermediate_size, + num_experts=fd_config.model_config.deepseekv3.n_routed_experts, + top_k=fd_config.model_config.deepseekv3.num_experts_per_tok, + topk_method=fd_config.model_config.deepseekv3.topk_method, + topk_group=fd_config.model_config.deepseekv3.topk_group, + n_group=fd_config.model_config.deepseekv3.n_group, + routed_scaling_factor=fd_config.model_config.deepseekv3. + routed_scaling_factor, + layer_idx=layer_id, + weight_key_map=weight_key_map, + ) + + self.num_shared_experts = fd_config.model_config.deepseekv3.n_shared_experts + shared_experts_intermediate_size = ( + self.num_shared_experts * + fd_config.model_config.deepseekv3.moe_intermediate_size) + + self.shared_experts = DeepSeekV3MLP( + fd_config=fd_config, + intermediate_size=shared_experts_intermediate_size, + prefix=f"{prefix}.shared_experts", + reduce_results=False, + ) + + def load_state_dict(self, state_dict): + """ + """ + self.fused_moe.load_state_dict(state_dict) + self.shared_experts.load_state_dict(state_dict) + + def forward(self, hidden_states: paddle.Tensor): + """ + """ + shared_experts_out = self.shared_experts(hidden_states) + moe_out = self.fused_moe(hidden_states) + moe_out = moe_out + shared_experts_out + # We do to TP all reduce after the sum of experts. 
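+        # Both the routed experts (FusedMoE above, reduce_results=False) and the
+        # shared-experts MLP are built with their internal all-reduce disabled, so
+        # a single tensor-parallel all-reduce below covers the summed output.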
+ if self.tp_size > 1: + tensor_model_parallel_all_reduce(moe_out) + return moe_out + + +class DeepseekV3MLAAttention(nn.Layer): + """ + DeepseekV3MLAAttention + """ + + def __init__(self, + fd_config: FDConfig, + layer_id: int, + prefix: str = "") -> None: + super().__init__() + + self.tp_size = fd_config.parallel_config.tensor_parallel_degree + self.hidden_size = fd_config.model_config.hidden_size + self.num_attention_heads = fd_config.model_config.num_attention_heads + self.num_attention_heads_tp = self.num_attention_heads // self.tp_size + + # MLA + self.qk_nope_head_dim = fd_config.model_config.deepseekv3.qk_nope_head_dim + self.qk_rope_head_dim = fd_config.model_config.deepseekv3.qk_rope_head_dim + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim = fd_config.model_config.deepseekv3.v_head_dim + self.q_lora_rank = fd_config.model_config.deepseekv3.q_lora_rank + self.kv_lora_rank = fd_config.model_config.deepseekv3.kv_lora_rank + + self.attn_softmax_scale = self.qk_head_dim**-0.5 + self.rope_theta = fd_config.model_config.rope_theta + self.rms_norm_eps = fd_config.model_config.rms_norm_eps + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(fd_config=fd_config, + prefix=f"{prefix}.q_a_proj", + input_size=self.hidden_size, + output_size=self.q_lora_rank, + with_bias=False) + + self.q_a_layernorm = RMSNorm(fd_config, + hidden_size=self.q_lora_rank, + eps=self.rms_norm_eps, + prefix=f"{prefix}.q_a_layernorm") + + self.q_b_proj = ColumnParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.q_b_proj", + input_size=self.q_lora_rank, + output_size=self.num_attention_heads * self.qk_head_dim, + with_bias=False, + ) + else: + assert (self.q_lora_rank is not None + ), "self.q_lora_rank is None, Please Check your config." 
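+        # Low-rank query path of MLA: q_a_proj compresses hidden_size down to
+        # q_lora_rank, q_a_layernorm normalizes the compressed query, and q_b_proj
+        # expands it back to num_attention_heads * (qk_nope_head_dim + qk_rope_head_dim);
+        # configs without q_lora_rank are rejected by the assert above.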
+ + # 不切TP,跑 W4A16 Gemm + self.kv_a_proj_with_mqa = ReplicatedLinear( + fd_config=fd_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + input_size=self.hidden_size, + output_size=self.kv_lora_rank + self.qk_rope_head_dim, + with_bias=False) + + self.kv_a_layernorm = RMSNorm(fd_config, + hidden_size=self.kv_lora_rank, + eps=self.rms_norm_eps, + prefix=f"{prefix}.kv_a_layernorm") + + self.kv_b_proj = ColumnParallelLinear( + fd_config=fd_config, + prefix=f"{prefix}.kv_b_proj", + input_size=self.kv_lora_rank, + output_size=self.num_attention_heads * + (self.qk_nope_head_dim + self.v_head_dim), + with_bias=False, + ) + + self.o_proj = RowParallelLinear(fd_config, + prefix=f"{prefix}.o_proj", + input_size=self.num_attention_heads * + self.v_head_dim, + output_size=self.hidden_size, + with_bias=False) + + self.kv_b_proj_bmm = KVBatchLinear( + fd_config=fd_config, + prefix=f"{prefix}.kv_b_proj", + kv_lora_rank=self.kv_lora_rank, + num_attention_heads=self.num_attention_heads, + qk_nope_head_dim=self.qk_nope_head_dim, + v_head_dim=self.v_head_dim) + + self.rope_scaling = fd_config.model_config.deepseekv3.rope_scaling + if self.rope_scaling: + mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False) + scaling_factor = self.rope_scaling["factor"] + mscale = self.yarn_get_mscale(scaling_factor, + float(mscale_all_dim)) + self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale + + rope_scaling_kwargs = { + key: self.rope_scaling[key] + for key in [ + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] if key in self.rope_scaling + } + self.rope_scaling_factor = self.rope_scaling["factor"] + self.rope_scaling_original_max_position_embeddings = self.rope_scaling[ + "original_max_position_embeddings"] + self.rotary_emb = DeepseekScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self. 
+ rope_scaling_original_max_position_embeddings, + base=self.rope_theta, + scaling_factor=self.rope_scaling_factor, + **rope_scaling_kwargs, + ) + + self.mla_attn = Attention( + fd_config=fd_config, + layer_id=layer_id, + prefix=prefix, + use_neox_rotary_style=False, + ) + + self.prefix = prefix + + @staticmethod + def yarn_get_mscale(scale=1, mscale=1): + """ + """ + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + def forward( + self, + forward_meta: ForwardMeta, + hidden_states: paddle.Tensor, + position_ids: paddle.Tensor, + mask_encoder_batch: paddle.Tensor, + ): + """ + """ + layernorm_out = hidden_states + fmha_out = paddle.zeros(shape=[ + layernorm_out.shape[0], + self.num_attention_heads_tp * self.v_head_dim + ], + dtype=layernorm_out.dtype) + + decode_stage = forward_meta.is_decode_batch + prefill_stage = not (forward_meta.is_decode_batch) + + if prefill_stage: + query = self.q_a_proj(layernorm_out) + query = self.q_a_layernorm(query) + query = self.q_b_proj(query) + + query = query.reshape( + [-1, self.num_attention_heads_tp, self.qk_head_dim]) + query_nope, query_pe = query.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) + + compressed_kv = self.kv_a_proj_with_mqa(layernorm_out) + compressed_kv, key_pe = compressed_kv.split( + [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) + key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim]) + compressed_kv = self.kv_a_layernorm(compressed_kv) + + query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe) + + key_value = self.kv_b_proj(compressed_kv) + key_value = key_value.reshape([ + -1, self.num_attention_heads_tp, + self.qk_nope_head_dim + self.v_head_dim + ]) + key_nope, value = key_value.split( + [self.qk_nope_head_dim, self.v_head_dim], axis=-1) + + query[..., self.qk_nope_head_dim:] = query_pe + key = paddle.empty_like(query) + key[..., :self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim:] = key_pe + value = paddle.nn.functional.pad( + value, [0, self.qk_head_dim - self.v_head_dim], value=0) + + fmha_out_prefill = self.mla_attn(q=query, + k=key, + v=value, + qkv=None, + compressed_kv=compressed_kv, + k_pe=key_pe, + forward_meta=forward_meta) + + fmha_out_prefill = fmha_out_prefill.reshape( + [-1, self.num_attention_heads_tp, self.qk_head_dim]) + fmha_out_prefill = fmha_out_prefill[:, :, :self.v_head_dim] + fmha_out_prefill = fmha_out_prefill.reshape( + [-1, self.num_attention_heads_tp * self.v_head_dim]) + fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast( + fmha_out_prefill.dtype) + + fmha_out = fmha_out + fmha_out_prefill + + if decode_stage: + query = self.q_a_proj(layernorm_out) + query = self.q_a_layernorm(query) + ln_out_or_q_c = query + + compressed_kv = self.kv_a_proj_with_mqa(layernorm_out) + compressed_kv, key_pe = compressed_kv.split( + [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) + key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim]) + compressed_kv = self.kv_a_layernorm(compressed_kv) + + query = self.q_b_proj(ln_out_or_q_c) + query = query.reshape( + [-1, self.num_attention_heads_tp, self.qk_head_dim]) + + query_nope, query_pe = query.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) + query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe) + + q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), + proj_type='k').transpose([1, 0, 2]) + + q_input = paddle.concat([q_nope_out, query_pe], axis=-1) + q_input = q_input.reshape([ + -1, + self.num_attention_heads_tp * + (self.kv_lora_rank 
+ self.qk_rope_head_dim), + ]) + fmha_out_decode = self.mla_attn(q=q_input, + k=None, + v=None, + qkv=None, + compressed_kv=compressed_kv, + k_pe=key_pe, + forward_meta=forward_meta) + + fmha_out_decode = fmha_out_decode.reshape( + [-1, self.num_attention_heads_tp, + self.kv_lora_rank]).transpose([1, 0, 2]) + + fmha_out_decode = (self.kv_b_proj_bmm( + fmha_out_decode, proj_type='v').transpose([1, 0, 2]).reshape( + [-1, self.num_attention_heads_tp * self.v_head_dim])) + fmha_out = fmha_out + fmha_out_decode + + output = self.o_proj(fmha_out) + return output + + def load_state_dict(self, state_dict): + """ + """ + self.q_a_proj.load_state_dict(state_dict) + self.q_a_layernorm.load_state_dict(state_dict) + self.kv_a_proj_with_mqa.load_state_dict(state_dict) + self.kv_a_layernorm.load_state_dict(state_dict) + self.q_b_proj.load_state_dict(state_dict) + self.kv_b_proj_bmm.load_state_dict(state_dict) + self.kv_b_proj.load_state_dict(state_dict) + # NOTE(Ryan):Make sure kv_b_proj_bmm loaded before kv_b_proj, + # The same weight key will be poped after kv_b_proj. + self.o_proj.load_state_dict(state_dict) + + +class DeepSeekV3DecoderLayer(nn.Layer): + """ + DeepSeekV3DecoderLayer + """ + + def __init__( + self, + fd_config: FDConfig, + prefix: str = "", + ) -> None: + super().__init__() + layer_id = int(prefix.split(sep='.')[-1]) + + self.self_attn = DeepseekV3MLAAttention( + fd_config=fd_config, + layer_id=layer_id, + prefix=f"{prefix}.self_attn", + ) + + if (fd_config.model_config.deepseekv3.n_routed_experts is not None + and layer_id + >= fd_config.model_config.deepseekv3.first_k_dense_replace): + self.mlp = DeepSeekV3MoE( + fd_config=fd_config, + layer_id=layer_id, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepSeekV3MLP( + fd_config=fd_config, + intermediate_size=fd_config.model_config.intermediate_size, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=f"{prefix}.input_layernorm", + ) + + self.post_attention_layernorm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=f"{prefix}.post_attention_layernorm", + ) + + def load_state_dict(self, state_dict): + """ + """ + self.self_attn.load_state_dict(state_dict) + self.mlp.load_state_dict(state_dict) + self.input_layernorm.load_state_dict(state_dict) + self.post_attention_layernorm.load_state_dict(state_dict) + + def forward( + self, + forward_meta: ForwardMeta, + hidden_states: paddle.Tensor, + residual: paddle.Tensor, + position_ids: paddle.Tensor, + mask_encoder_batch: paddle.Tensor, + ): + """ + """ + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn(forward_meta, hidden_states, + position_ids, mask_encoder_batch) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeepSeekV3Model(nn.Layer): + """ + DeepSeekV3Model + """ + + def __init__( + self, + fd_config: FDConfig = None, + ): + """ + Initializer for the DeepSeekV3Model class. 
+ """ + super().__init__() + self.num_layers = fd_config.model_config.num_layers + fd_config.model_config.prefix_name = "deepseek_v3" + + self.embeddings = VocabParallelEmbedding( + fd_config, + num_embeddings=fd_config.model_config.vocab_size, + embedding_dim=fd_config.model_config.hidden_size, + params_dtype=paddle.get_default_dtype(), + prefix="deepseek_v3.embed_tokens", + ) + + self.decoder_layers = nn.LayerList([ + DeepSeekV3DecoderLayer( + fd_config, + prefix=f"{fd_config.model_config.prefix_name}.layers.{i}") + for i in range(self.num_layers) + ]) + + self.norm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix="deepseek_v3.norm", + ) + + def pre_process(self, forward_meta): + """ + """ + seq_lens_encoder = forward_meta.seq_lens_encoder + seq_lens_decoder = forward_meta.seq_lens_decoder + seq_lens_this_time = forward_meta.seq_lens_this_time + position_ids_shape = paddle.sum(seq_lens_this_time) + + position_ids = paddle.empty(shape=position_ids_shape, + dtype=seq_lens_encoder.dtype) + mask_encoder_batch = paddle.empty( + shape=position_ids_shape, + dtype=seq_lens_encoder.dtype).unsqueeze(1) + + get_position_ids_and_mask_encoder_batch(seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + position_ids, + mask_encoder_batch) + + return position_ids, mask_encoder_batch + + def load_state_dict(self, state_dict): + """ + Load model parameters from a given state dictionary. + """ + self.embeddings.load_state_dict(state_dict) + self.norm.load_state_dict(state_dict) + for i in range(self.num_layers): + logger.info(f"Start load layer {i}") + self.decoder_layers[i].load_state_dict(state_dict) + + def forward( + self, + ids_remove_padding: paddle.Tensor, + forward_meta: ForwardMeta, + ): + """ + """ + hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding) + + position_ids, mask_encoder_batch = self.pre_process(forward_meta) + + residual = None + for i in range(self.num_layers): + hidden_states, residual = self.decoder_layers[i]( + forward_meta, hidden_states, residual, position_ids, + mask_encoder_batch) + hidden_states = hidden_states + residual + out = self.norm(hidden_states) + + return out + + +class DeepseekV3ForCausalLM(ModelForCasualLM): + """ + DeepseekV3ForCausalLM + """ + + def __init__(self, fd_config: FDConfig): + """ + Args: + fd_config (FDConfig): Configurations for the LLM model. + """ + super().__init__(fd_config) + self.model = DeepSeekV3Model(fd_config) + self.ori_vocab_size = fd_config.model_config.ori_vocab_size + self.lm_head = ParallelLMHead( + fd_config, + embedding_dim=fd_config.model_config.hidden_size, + num_embeddings=fd_config.model_config.vocab_size, + prefix="lm_head", + ) + + @classmethod + def name(cls): + """ + """ + return "DeepseekV3ForCausalLM" + + @paddle.no_grad() + def set_state_dict(self, state_dict): + """ + Load model parameters from a given state dictionary. 
+ """ + self.model.load_state_dict(state_dict) + self.lm_head.load_state_dict(state_dict) + + def compute_logits(self, hidden_states: paddle.Tensor): + """ + """ + logits = self.lm_head(hidden_states) + logits = paddle.cast(logits, paddle.float32) + logits[:, self.ori_vocab_size:] = -float("inf") + return logits + + def forward( + self, + ids_remove_padding: paddle.Tensor, + forward_meta: ForwardMeta, + ): + """ + """ + hidden_states = self.model(ids_remove_padding, forward_meta) + return hidden_states + + +class DeepSeekV3PretrainedModel(PretrainedModel): + """ + DeepSeekV3PretrainedModel + """ + + config_class = FDConfig + + def _init_weight(self, layer): + """ + _init_weight + """ + return None + + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + + logger.info("DeepseekV3 inference model _get_tensor_parallel_mappings") + + from paddleformers.transformers.conversion_utils import \ + split_or_merge_func + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, + is_column=False), + } + + # Self Attention Layer which are need TP. + base_actions["layers.0.self_attn.q_b_proj.weight"] = partial( + fn, is_column=True) + base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial( + fn, is_column=True) + base_actions[ + "layers.0.self_attn.q_b_proj.weight_scale_inv"] = partial( + fn, is_column=True) + base_actions[ + "layers.0.self_attn.kv_b_proj.weight_scale_inv"] = partial( + fn, is_column=True) + + # MLP Layer + base_actions["layers.0.mlp.gate_proj.weight"] = partial( + fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial( + fn, is_column=True) + base_actions["layers.0.mlp.down_proj.weight"] = partial( + fn, is_column=False) + + # Moe Layer + for expert_idx in range(config.n_routed_experts): + base_actions[ + f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial( + fn, is_column=True) + base_actions[ + f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial( + fn, is_column=True) + base_actions[ + f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial( + fn, is_column=False) + + # Shared Expert Layer + base_actions[ + "layers.0.mlp.shared_experts.up_proj.weight"] = partial( + fn, is_column=True) + base_actions[ + "layers.0.mlp.shared_experts.gate_proj.weight"] = partial( + fn, is_column=True) + base_actions[ + "layers.0.mlp.shared_experts.down_proj.weight"] = partial( + fn, is_column=False) + + # MTP parts + base_actions["layers.61.embed_tokens.weight"] = partial( + fn, is_column=False) + base_actions["layers.61.eh_proj.weight"] = partial(fn, + is_column=True) + base_actions["layers.61.shared_head.head.weight"] = partial( + fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." 
in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", + f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_layers) + return mappings diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 669438f4f..3a74f4114 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -37,291 +37,13 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.model_base import ModelForCasualLM +from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm +from fastdeploy.model_executor.models.utils import \ + LayerIdPlaceholder as layerid +from fastdeploy.model_executor.models.utils import WeightMeta from fastdeploy.worker.forward_meta import ForwardMeta -class Ernie4_5_PretrainedModel(PretrainedModel): - """ - Ernie4_5_PretrainedModel - """ - - config_class = FDConfig - - def _init_weight(self, layer): - """ - _init_weight - """ - return None - - @classmethod - def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True): - """ - get_tensor_parallel_mappings - """ - logger.info("erine inference model _get_tensor_parallel_mappings") - - from paddleformers.transformers.conversion_utils import \ - split_or_merge_func - - fn = split_or_merge_func( - is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, - tensor_parallel_rank=config.tensor_parallel_rank, - num_attention_heads=config.num_attention_heads, - ) - - def gqa_qkv_split_func( - weight, - tensor_parallel_degree, - tensor_parallel_rank, - num_attention_heads, - num_key_value_heads, - head_dim, - ): - - def get_shape(tensor): - return (tensor.get_shape() - if hasattr(tensor, "get_shape") else tensor.shape) - - def slice_tensor(tensor, start, end): - shape = get_shape(tensor) - if len(shape) == 1: - return tensor[start:end] - else: - return tensor[..., start:end] - - q_end = num_attention_heads * head_dim - k_end = q_end + num_key_value_heads * head_dim - v_end = k_end + num_key_value_heads * head_dim - - q = slice_tensor(weight, 0, q_end) - k = slice_tensor(weight, q_end, k_end) - v = slice_tensor(weight, k_end, v_end) - - def split_tensor(tensor, degree): - shape = get_shape(tensor) - size = shape[-1] - block_size = size // degree - if hasattr(tensor, "get_shape"): - return [ - slice_tensor(tensor, i * block_size, - (i + 1) * block_size) - for i in range(degree) - ] - else: - return np.split(tensor, degree, axis=-1) - - q_list = split_tensor(q, tensor_parallel_degree) - k_list = split_tensor(k, tensor_parallel_degree) - v_list = split_tensor(v, tensor_parallel_degree) - - if tensor_parallel_rank is None: - return [ - np.concatenate([q_i, k_i, v_i], axis=-1) - for q_i, k_i, v_i in zip(q_list, k_list, v_list) - ] - else: - return np.concatenate( - [ - q_list[tensor_parallel_rank], - k_list[tensor_parallel_rank], - v_list[tensor_parallel_rank], - ], - axis=-1, - ) - - def gqa_qkv_merge_func(weight_list, num_attention_heads, - num_key_value_heads, head_dim): - tensor_parallel_degree = len(weight_list) - num_attention_heads = num_attention_heads // tensor_parallel_degree - num_key_value_heads = num_key_value_heads // tensor_parallel_degree - - is_paddle_tensor = not isinstance(weight_list[0], 
np.ndarray) - - def get_shape(tensor): - return (tensor.get_shape() - if hasattr(tensor, "get_shape") else tensor.shape) - - def slice_tensor(tensor, start, end): - if len(get_shape(tensor)) == 1: - return tensor[start:end] - else: - return tensor[..., start:end] - - q_list, k_list, v_list = [], [], [] - - for weight in weight_list: - q_end = num_attention_heads * head_dim - k_end = q_end + num_key_value_heads * head_dim - v_end = k_end + num_key_value_heads * head_dim - - q = slice_tensor(weight, 0, q_end) - k = slice_tensor(weight, q_end, k_end) - v = slice_tensor(weight, k_end, v_end) - - q_list.append(q) - k_list.append(k) - v_list.append(v) - - merged = q_list + k_list + v_list - - if is_paddle_tensor: - tensor = paddle.concat(merged, axis=-1) - if tensor.place.is_gpu_place(): - tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) - return tensor - else: - return np.concatenate(merged, axis=-1) - - if (config.num_key_value_heads is not None - and config.num_key_value_heads != config.num_attention_heads): - if is_split: - qkv_fn = partial( - gqa_qkv_split_func, - tensor_parallel_degree=config.tensor_parallel_degree, - tensor_parallel_rank=config.tensor_parallel_rank, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - head_dim=config.head_dim, - ) - else: - qkv_fn = partial( - gqa_qkv_merge_func, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - head_dim=config.head_dim, - ) - else: - qkv_fn = partial(fn, is_column=True) - - def get_tensor_parallel_split_mappings(num_layers, moe_num_experts, - moe_num_shared_experts, - moe_layer_start_index): - - final_actions = {} - - base_model_prefix = "ernie" - base_actions = { - "lm_head.weight": - partial(fn, is_column=True), - # "eh_proj.weight": partial(fn, is_column=True), - f"{base_model_prefix}.embed_tokens.weight": - partial(fn, is_column=False), - } - - base_actions[ - f"{base_model_prefix}.layers.0.self_attn.qkv_proj.weight"] = qkv_fn - base_actions[ - f"{base_model_prefix}.layers.0.self_attn.qkv_proj.quant_weight"] = qkv_fn - base_actions[ - f"{base_model_prefix}.layers.0.self_attn.o_proj.weight"] = partial( - fn, is_column=False) - base_actions[ - f"{base_model_prefix}.layers.0.self_attn.o_proj.quant_weight"] = partial( - fn, is_column=False) - base_actions[ - f"{base_model_prefix}.layers.0.mlp.up_gate_proj.weight"] = partial( - fn, is_column=True, is_naive_2fuse=True) - base_actions[ - f"{base_model_prefix}.layers.0.mlp.up_gate_proj.quant_weight"] = partial( - fn, is_column=True, is_naive_2fuse=True) - base_actions[ - f"{base_model_prefix}.layers.0.mlp.down_proj.weight"] = ( - partial(fn, is_column=False)) - base_actions[ - f"{base_model_prefix}.layers.0.mlp.down_proj.quant_weight"] = partial( - fn, is_column=False) - - for expert_idx in range(moe_num_experts): - base_actions[ - f"{base_model_prefix}.layers.{moe_layer_start_index}" - f".mlp.experts.{expert_idx}.up_gate_proj.weight"] = partial( - fn, is_column=True, is_naive_2fuse=True) - base_actions[ - f"{base_model_prefix}.layers.{moe_layer_start_index}" - f".mlp.experts.{expert_idx}.up_gate_proj.quant_weight"] = partial( - fn, is_column=True, is_naive_2fuse=True) - base_actions[ - f"{base_model_prefix}.layers.{moe_layer_start_index}" - f".mlp.experts.{expert_idx}.down_proj.weight"] = partial( - fn, is_column=False) - base_actions[ - f"{base_model_prefix}.layers.{moe_layer_start_index}" - f".mlp.experts.{expert_idx}.down_proj.quant_weight"] = partial( - fn, is_column=False) - - 
if moe_num_shared_experts > 0: - base_actions[ - f"{base_model_prefix}.layers.{moe_layer_start_index}" - f".mlp.shared_experts.up_gate_proj.weight"] = partial( - fn, is_column=True, is_naive_2fuse=True) - base_actions[ - f"{base_model_prefix}.layers.{moe_layer_start_index}" - f".mlp.shared_experts.up_gate_proj.quant_weight"] = partial( - fn, is_column=True, is_naive_2fuse=True) - base_actions[ - f"{base_model_prefix}.layers.{moe_layer_start_index}" - f".mlp.shared_experts.down_proj.weight"] = partial( - fn, is_column=False) - base_actions[ - f"{base_model_prefix}.layers.{moe_layer_start_index}" - f".mlp.shared_experts.up_gate_proj.quant_weight"] = partial( - fn, is_column=False, is_naive_2fuse=True) - - for key, action in base_actions.items(): - if (f"{base_model_prefix}.layers.0.mlp.up_gate_proj.weight" - in key or - f"{base_model_prefix}.layers.0.mlp.up_gate_proj.quant_weight" - in key - or f"{base_model_prefix}.layers.0.mlp.down_proj.weight" - in key or - f"{base_model_prefix}.layers.0.mlp.down_proj.quant_weight" - in key): - for i in range(moe_layer_start_index): - final_actions[key.replace("layers.0.", - f"layers.{i}.")] = action - elif f"layers.{moe_layer_start_index}.mlp.experts." in key: - for i in range(moe_layer_start_index, num_layers): - final_actions[key.replace( - f"layers.{moe_layer_start_index}.", - f"layers.{i}.")] = action - elif f"layers.{moe_layer_start_index}.mlp.shared_experts." in key: - for i in range(moe_layer_start_index, num_layers): - final_actions[key.replace( - f"layers.{moe_layer_start_index}.", - f"layers.{i}.")] = action - elif f"{base_model_prefix}.layers.0." in key: - for i in range(num_layers): - final_actions[key.replace("layers.0.", - f"layers.{i}.")] = action - final_actions[key] = action - return final_actions - - moe_num_experts = 0 - moe_num_shared_experts = 0 - if isinstance(config.moe_num_experts, list): - moe_num_experts = sum(config.moe_num_experts) - elif isinstance(config.moe_num_experts, int): - moe_num_experts = config.moe_num_experts - if hasattr(config, 'moe_num_shared_experts'): - moe_num_shared_experts = config.moe_num_shared_experts - - moe_layer_start_index = -1 - if isinstance(config.moe_layer_start_index, list): - moe_layer_start_index = min(config.moe_layer_start_index) - elif isinstance(config.moe_layer_start_index, int): - moe_layer_start_index = config.moe_layer_start_index - - mappings = get_tensor_parallel_split_mappings( - config.num_layers, - moe_num_experts, - moe_num_shared_experts, - moe_layer_start_index, - ) - - return mappings - - class Ernie4_5_MLP(nn.Layer): def __init__( @@ -329,6 +51,7 @@ class Ernie4_5_MLP(nn.Layer): fd_config: FDConfig, intermediate_size: int, prefix: str = "", + reduce_results: bool = True, ) -> None: super().__init__() self.nranks = fd_config.parallel_config.tensor_parallel_degree @@ -345,7 +68,7 @@ class Ernie4_5_MLP(nn.Layer): self.down_proj = RowParallelLinear( fd_config=fd_config, prefix=f"{prefix}.down_proj", - input_size=(intermediate_size // self.nranks), + input_size=intermediate_size, output_size=fd_config.model_config.hidden_size, with_bias=False, ) @@ -423,8 +146,8 @@ class Ernie4_5_MoE(nn.Layer): f"{prefix}.experts.{{}}.down_proj.code_zp", } elif moe_quant_type == "tensor_wise_fp8" or ( - moe_quant_type == "block_wise_fp8" and - fd_config.model_config.is_quantized): + moe_quant_type == "block_wise_fp8" + and fd_config.model_config.is_quantized): weight_key_map = { "gate_weight_key": f"{prefix}.gate.weight", @@ -492,8 +215,6 @@ class Ernie4_5_Attention(nn.Layer): prefix: str) -> 
None: super().__init__() - nranks = fd_config.parallel_config.tensor_parallel_degree - self.qkv_proj = QKVParallelLinear( fd_config=fd_config, prefix=f"{prefix}.qkv_proj", @@ -502,8 +223,8 @@ class Ernie4_5_Attention(nn.Layer): self.o_proj = RowParallelLinear( fd_config=fd_config, prefix=f"{prefix}.o_proj", - input_size=(fd_config.model_config.head_dim * - fd_config.model_config.num_attention_heads // nranks), + input_size=fd_config.model_config.head_dim * + fd_config.model_config.num_attention_heads, output_size=fd_config.model_config.hidden_size, ) self.attn = Attention( @@ -636,12 +357,12 @@ class Ernie4_5_Model(nn.Layer): params_dtype=paddle.get_default_dtype(), prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens")) - self.hidden_layers = [ + self.hidden_layers = nn.LayerList([ Ernie4_5_DecoderLayer( fd_config=fd_config, prefix=f"{fd_config.model_config.prefix_name}.layers.{i}") for i in range(self.num_layers) - ] + ]) self.norm = RMSNorm( fd_config, @@ -772,3 +493,134 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM): Model Architecture Name """ return "Ernie4_5_ForCausalLM" + + +class Ernie4_5_PretrainedModel(PretrainedModel): + """ + Ernie4_5_PretrainedModel + """ + + config_class = FDConfig + + def _init_weight(self, layer): + """ + _init_weight + """ + return None + + weight_infos = [ + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight", + True, tsm.GQA), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight", + False), + WeightMeta( + f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.weight", + True, tsm.PairFused), + WeightMeta(f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.weight", + False), + WeightMeta( + f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.weight", + True, tsm.PairFused), + WeightMeta( + f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.weight", + False), + WeightMeta( + f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.weight", + True, tsm.PairFused), + WeightMeta( + f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight", + False), + WeightMeta(".embed_tokens.weight", False), + WeightMeta("lm_head.weight", True), + # quant tensorwise + WeightMeta( + f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.quant_weight", + True, tsm.GQA), + WeightMeta( + f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.quant_weight", + False), + WeightMeta( + f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.quant_weight", + True, tsm.PairFused), + WeightMeta( + f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.quant_weight", + False), + WeightMeta( + f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.quant_weight", + True, tsm.PairFused), + WeightMeta( + f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.quant_weight", + False), + WeightMeta( + f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.quant_weight", + True, tsm.PairFused), + WeightMeta( + f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.quant_weight", + False), + ] + + @classmethod + def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True): + """ + get_tensor_parallel_mappings + """ + logger.info("erine inference model _get_tensor_parallel_mappings") + from fastdeploy.model_executor.models.tp_utils import ( + build_expanded_keys, has_prefix, split_or_merge_func_v1) + + fn = split_or_merge_func_v1( + is_split=is_split, + 
tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim) + + def get_tensor_parallel_split_mappings(num_layers, moe_num_experts, + moe_layer_start_index, + prefix_name): + base_actions = {} + weight_infos = cls.weight_infos + for (weight_name, is_column, extra) in weight_infos: + params = { + "is_column": is_column, + **({ + extra.value: True + } if extra else {}) + } + + if "lm_head.weight" in weight_name: + key = weight_name + elif not has_prefix(prefix_name, weight_name): + key = f"{prefix_name}{weight_name}" + else: + key = weight_name + base_actions[key] = partial(fn, **params) + final_actions = {} + start_layer = (moe_layer_start_index + if moe_layer_start_index > 0 else num_layers) + final_actions = build_expanded_keys( + num_layers, + moe_num_experts, + start_layer, + base_actions, + ) + return final_actions + + moe_num_experts = 0 + if isinstance(config.moe_num_experts, list): + moe_num_experts = sum(config.moe_num_experts) + elif isinstance(config.moe_num_experts, int): + moe_num_experts = config.moe_num_experts + + moe_layer_start_index = -1 + if isinstance(config.moe_layer_start_index, list): + moe_layer_start_index = min(config.moe_layer_start_index) + elif isinstance(config.moe_layer_start_index, int): + moe_layer_start_index = config.moe_layer_start_index + + mappings = get_tensor_parallel_split_mappings(config.num_layers, + moe_num_experts, + moe_layer_start_index, + config.prefix_name) + return mappings diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py index 84c940b92..0a08e4b56 100644 --- a/fastdeploy/model_executor/models/ernie4_5_mtp.py +++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py @@ -265,12 +265,12 @@ class Ernie4_5_MTPModel(nn.Layer): self.num_layers = fd_config.model_config.num_layers self.embeddings = fd_config.speculative_config.sharing_model.model.embeddings - self.hidden_layers = [ + self.hidden_layers = nn.LayerList([ Ernie4_5_DecoderLayer( fd_config=fd_config, prefix=f"{fd_config.model_config.prefix_name}.{i}") for i in range(self.num_layers) - ] + ]) self.enorm = RMSNorm( fd_config, diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 603b14a8e..a08433a57 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -25,6 +25,8 @@ from paddle import nn from paddleformers.utils.log import logger from fastdeploy.config import FDConfig +from fastdeploy.distributed.communication_op import \ + tensor_model_parallel_all_reduce from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.moe.moe import FusedMoE @@ -66,6 +68,7 @@ class Ernie4_5_VLMoE(nn.Layer): prefix: str) -> None: super().__init__() + self.tp_size = fd_config.parallel_config.tensor_parallel_degree moe_layer_start_index = fd_config.moe_config.moe_layer_start_index if isinstance(moe_layer_start_index, int): text_moe_layer_start_index = moe_layer_start_index @@ -99,6 +102,7 @@ class Ernie4_5_VLMoE(nn.Layer): } self.mlp_text = FusedMoE( fd_config=fd_config, + reduce_results=False, moe_intermediate_size=fd_config.moe_config. 
moe_intermediate_size[0], num_experts=fd_config.moe_config.num_experts[0], @@ -130,6 +134,7 @@ class Ernie4_5_VLMoE(nn.Layer): } self.mlp_image = FusedMoE( fd_config=fd_config, + reduce_results=False, moe_intermediate_size=fd_config.moe_config. moe_intermediate_size[1], num_experts=fd_config.moe_config.num_experts[1], @@ -154,6 +159,7 @@ class Ernie4_5_VLMoE(nn.Layer): intermediate_size=self.num_shared_experts * fd_config.moe_config.moe_intermediate_size[0], prefix=f"{prefix}.shared_experts", + reduce_results=False, ) def extract_gate_correction_bias_text(self, gate_correction_bias_key, @@ -210,6 +216,8 @@ class Ernie4_5_VLMoE(nn.Layer): hidden_states = self.mlp_text(hidden_states) if self.num_shared_experts > 0: hidden_states += share_experts_out + if self.tp_size > 1: + tensor_model_parallel_all_reduce(hidden_states) return hidden_states @@ -337,12 +345,12 @@ class Ernie4_5_VLModel(nn.Layer): prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"), ) - self.hidden_layers = [ + self.hidden_layers = nn.LayerList([ Ernie4_5_VLDecoderLayer( fd_config=fd_config, prefix=f"{fd_config.model_config.prefix_name}.layers.{i}") for i in range(self.num_layers) - ] + ]) self.norm = RMSNorm( fd_config, diff --git a/fastdeploy/model_executor/models/model_base.py b/fastdeploy/model_executor/models/model_base.py index be6ba470c..2e252fd0e 100644 --- a/fastdeploy/model_executor/models/model_base.py +++ b/fastdeploy/model_executor/models/model_base.py @@ -28,15 +28,17 @@ class ModelRegistry: _registry = {} @classmethod - def register(cls, model_class): + def register(cls, model_class, suffix=""): + """register model class""" if issubclass( model_class, ModelForCasualLM) and model_class is not ModelForCasualLM: - cls._registry[model_class.name()] = model_class + cls._registry[f"{model_class.name()}{suffix}"] = model_class return model_class @classmethod def get_class(cls, name): + """get model class""" if name not in cls._registry: raise ValueError(f"Model '{name}' is not registered!") return cls._registry[name] diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py index 27d83b09a..a8e6955db 100644 --- a/fastdeploy/model_executor/models/qwen2.py +++ b/fastdeploy/model_executor/models/qwen2.py @@ -61,7 +61,7 @@ class Qwen2MLP(nn.Layer): self.down_proj = RowParallelLinear( fd_config=fd_config, prefix=f"{prefix}.down_proj", - input_size=(fd_config.model_config.ffn_hidden_size // self.nranks), + input_size=fd_config.model_config.ffn_hidden_size, output_size=fd_config.model_config.hidden_size, with_bias=False, ) @@ -97,8 +97,6 @@ class Qwen2Attention(nn.Layer): prefix: str = "") -> None: super().__init__() - nranks = fd_config.parallel_config.tensor_parallel_degree - self.qkv_proj = QKVParallelLinear(fd_config=fd_config, prefix=f"{prefix}.qkv_proj", with_bias=True) @@ -106,7 +104,7 @@ class Qwen2Attention(nn.Layer): self.o_proj = RowParallelLinear( fd_config=fd_config, prefix=f"{prefix}.o_proj", - input_size=(fd_config.model_config.hidden_size // nranks), + input_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.hidden_size, ) diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index b3d0ea405..0c5ecc96f 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -68,7 +68,7 @@ class Qwen3Attention(nn.Layer): fd_config=fd_config, prefix=f"{prefix}.o_proj", input_size=fd_config.model_config.head_dim * - fd_config.model_config.num_attention_heads // 
nranks, + fd_config.model_config.num_attention_heads, output_size=fd_config.model_config.hidden_size, ) diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index 065647aca..f73063516 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -63,7 +63,7 @@ class Qwen3MLP(nn.Layer): self.down_proj = RowParallelLinear( fd_config, prefix=f"{prefix}.down_proj", - input_size=(fd_config.model_config.ffn_hidden_size // self.nranks), + input_size=fd_config.model_config.ffn_hidden_size, output_size=fd_config.model_config.hidden_size, with_bias=False, ) @@ -111,7 +111,7 @@ class Qwen3Attention(nn.Layer): fd_config, prefix=f"{prefix}.o_proj", input_size=fd_config.model_config.head_dim * - fd_config.model_config.num_attention_heads // nranks, + fd_config.model_config.num_attention_heads, output_size=fd_config.model_config.hidden_size, ) diff --git a/fastdeploy/model_executor/models/tp_utils.py b/fastdeploy/model_executor/models/tp_utils.py new file mode 100644 index 000000000..f360c5106 --- /dev/null +++ b/fastdeploy/model_executor/models/tp_utils.py @@ -0,0 +1,405 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import re +from enum import Enum +from functools import partial +from typing import Dict, List + +import numpy as np +import paddle +from paddleformers.transformers import PretrainedModel +from paddleformers.transformers.conversion_utils import split_or_merge_func +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.models.utils import LayerIdPlaceholder + + +def check_tensor_parallel_prerequisites( + fd_config: FDConfig, + cls: PretrainedModel, + tensor_parallel_filtered_map: Dict[str, partial], + safetensor_keys: List[str], +) -> None: + """check_tensor_parallel_prerequisites""" + if fd_config.parallel_config.tensor_parallel_degree > 1: + tensor_parallel_map = cls._get_tensor_parallel_mappings( + fd_config.model_config, is_split=True) + if not tensor_parallel_map: + logger.error("filtered_quant_map should not be empty. \ + parallel splitting required, but _get_tensor_parallel_mappings is not implemented." + ) + filtered_tp_keys = cls._resolve_prefix_keys(tensor_parallel_map.keys(), + safetensor_keys) + for k, v in filtered_tp_keys.items(): + tensor_parallel_filtered_map[v] = tensor_parallel_map.pop(k) + if not tensor_parallel_filtered_map: + logger.error("tensor_parallel_filtered_map should not be empty. \ + The weights required for tensor parallel splitting are inconsistent with the model's weights." 
+ ) + + +def extract_prefix(weight_name: str) -> str: + """extract_prefix""" + if weight_name.startswith("."): + return "" + parts = weight_name.split(".", 1) + return parts[0] if len(parts) > 1 else "" + + +def has_prefix(prefix_name: str, weight_name: str): + """has_prefix""" + return prefix_name == extract_prefix(weight_name) + + +class TensorSplitMode(Enum): + """TensorSplitMode""" + + GQA = "is_gqa" + TRANSPOSE = "transpose" + QKV = "is_old_qkv" + PairFused = "is_naive_2fuse" + TripletFused = "is_naive_3fuse" + + +def extract_placeholders(template: str): + """extract_placeholders""" + return set(re.findall(r"{(\w+)}", template)) + + +class SafeDict(dict): + """SafeDict""" + + def __missing__(self, key): + return "{" + key + "}" + + +def has_placeholders(placeholders): + """has_placeholders""" + return len(placeholders) > 0 + + +def update_final_actions(params, final_actions, key, action): + """update_final_actions""" + new_key = key.format_map(SafeDict(params)) + final_actions[new_key] = action + + +def build_expanded_keys(num_layers, num_experts, start_layer, base_actions): + """build_expanded_keys""" + final_actions = {} + for key, action in base_actions.items(): + placeholders = extract_placeholders(key) + if not has_placeholders(placeholders): + final_actions[key] = action + else: + if LayerIdPlaceholder.LAYER_ID.value in placeholders: + for layer_id in range(num_layers): + update_final_actions( + {LayerIdPlaceholder.LAYER_ID.value: layer_id}, + final_actions, + key, + action, + ) + elif LayerIdPlaceholder.FFN_LAYER_ID.value in placeholders: + for layer_id in range(start_layer): + update_final_actions( + {LayerIdPlaceholder.FFN_LAYER_ID.value: layer_id}, + final_actions, + key, + action, + ) + elif (LayerIdPlaceholder.MOE_LAYER_ID.value in placeholders + and LayerIdPlaceholder.EXPERT_ID.value in placeholders): + for layer_id in range(start_layer, num_layers): + for export_id in range(num_experts): + update_final_actions( + { + LayerIdPlaceholder.MOE_LAYER_ID.value: + layer_id, + LayerIdPlaceholder.EXPERT_ID.value: export_id, + }, + final_actions, + key, + action, + ) + elif (LayerIdPlaceholder.MOE_LAYER_ID.value in placeholders + and len(placeholders) == 1): + for layer_id in range(start_layer, num_layers): + update_final_actions( + {LayerIdPlaceholder.MOE_LAYER_ID.value: layer_id}, + final_actions, + key, + action, + ) + else: + logger.error(f"{key} does not match any case.") + return final_actions + + +def gqa_qkv_split_func( + tensor_parallel_degree, + tensor_parallel_rank, + num_attention_heads, + num_key_value_heads, + head_dim, +): + """ + gqa_qkv_split_func + """ + + def fn(x, is_column=True): + """fucn""" + + def get_shape(tensor): + """get_shape""" + return tensor.get_shape() if hasattr(tensor, + "get_shape") else tensor.shape + + def slice_tensor(tensor, start, end): + """slice_tensor""" + shape = get_shape(tensor) + if len(shape) == 1: + return tensor[start:end] + elif is_column: + return tensor[..., start:end] + else: + return tensor[start:end, ...] 
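+        # Fused weight layout along the split axis is [Q | K | V], with widths
+        # num_attention_heads*head_dim, num_key_value_heads*head_dim and
+        # num_key_value_heads*head_dim. Each block is split evenly across
+        # tensor-parallel ranks and re-concatenated per rank, e.g. with
+        # hypothetical sizes of 64 q heads, 8 kv heads, head_dim 128 and tp 8,
+        # every rank keeps 1024 q columns plus 128 k and 128 v columns.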
+ + q_end = num_attention_heads * head_dim + k_end = q_end + num_key_value_heads * head_dim + v_end = k_end + num_key_value_heads * head_dim + + q = slice_tensor(x, 0, q_end) + k = slice_tensor(x, q_end, k_end) + v = slice_tensor(x, k_end, v_end) + + def split_tensor(tensor, degree): + """ + split_tensor + """ + shape = get_shape(tensor) + size = shape[-1] if is_column else shape[0] + block_size = size // degree + if hasattr(tensor, "get_shape"): + return [ + slice_tensor(tensor, i * block_size, (i + 1) * block_size) + for i in range(degree) + ] + else: + if isinstance(x, paddle.Tensor): + if is_column: + return paddle.split(tensor, degree, axis=-1) + else: + return paddle.split(tensor, degree, axis=0) + else: + if is_column: + return np.split(tensor, degree, axis=-1) + else: + return np.split(tensor, degree, axis=0) + + q_list = split_tensor(q, tensor_parallel_degree) + k_list = split_tensor(k, tensor_parallel_degree) + v_list = split_tensor(v, tensor_parallel_degree) + + if tensor_parallel_rank is None: + res = [] + for q_i, k_i, v_i in zip(q_list, k_list, v_list): + if is_column: + if isinstance(x, paddle.Tensor): + res.append(paddle.concat([q_i, k_i, v_i], axis=-1)) + else: + res.append(np.concatenate([q_i, k_i, v_i], axis=-1)) + else: + if isinstance(x, paddle.Tensor): + res.append(paddle.concat([q_i, k_i, v_i], axis=0)) + else: + res.append(np.concatenate([q_i, k_i, v_i], axis=0)) + return res + else: + if isinstance(x, paddle.Tensor): + if is_column: + return paddle.concat( + [ + q_list[tensor_parallel_rank], + k_list[tensor_parallel_rank], + v_list[tensor_parallel_rank], + ], + axis=-1, + ) + else: + return paddle.concat( + [ + q_list[tensor_parallel_rank], + k_list[tensor_parallel_rank], + v_list[tensor_parallel_rank], + ], + axis=0, + ) + else: + if is_column: + return np.concatenate( + [ + q_list[tensor_parallel_rank], + k_list[tensor_parallel_rank], + v_list[tensor_parallel_rank], + ], + axis=-1, + ) + else: + return np.concatenate( + [ + q_list[tensor_parallel_rank], + k_list[tensor_parallel_rank], + v_list[tensor_parallel_rank], + ], + axis=0, + ) + + return fn + + +def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim): + """ + gqa_qkv_merge_func + """ + + def fn(weight_list, is_column=True): + """fn""" + tensor_parallel_degree = len(weight_list) + num_attention_heads = num_attention_heads // tensor_parallel_degree + num_key_value_heads = num_key_value_heads // tensor_parallel_degree + + is_paddle_tensor = not isinstance(weight_list[0], np.ndarray) + + def get_shape(tensor): + """ + get_shape + """ + return tensor.get_shape() if hasattr(tensor, + "get_shape") else tensor.shape + + def slice_tensor(tensor, start, end): + """ + slice_tensor + """ + if len(get_shape(tensor)) == 1: + return tensor[start:end] + elif is_column: + return tensor[..., start:end] + else: + return tensor[start:end, ...] 
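+        # Inverse of the split above: each rank's fused [q_i | k_i | v_i] slice is
+        # unpacked, then all Q parts, all K parts and all V parts are concatenated
+        # back into the original packed [Q | K | V] layout.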
+ + q_list, k_list, v_list = [], [], [] + + for weight in weight_list: + q_end = num_attention_heads * head_dim + k_end = q_end + num_key_value_heads * head_dim + v_end = k_end + num_key_value_heads * head_dim + + q = slice_tensor(weight, 0, q_end) + k = slice_tensor(weight, q_end, k_end) + v = slice_tensor(weight, k_end, v_end) + + q_list.append(q) + k_list.append(k) + v_list.append(v) + + merged = q_list + k_list + v_list + + if is_paddle_tensor: + if is_column: + tensor = paddle.concat(merged, axis=-1) + else: + tensor = paddle.concat(merged, axis=0) + if tensor.place.is_gpu_place(): + tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) + return tensor + else: + if is_column: + return np.concatenate(merged, axis=-1) + else: + return np.concatenate(merged, axis=0) + + return fn + + +def split_or_merge_qkv_func( + is_split, + tensor_parallel_degree, + tensor_parallel_rank, + num_attention_heads, + num_key_value_heads, + head_dim, +): + """ + split_or_merge_qkv_func + """ + if is_split: + return gqa_qkv_split_func( + tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_rank=tensor_parallel_rank, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + ) + else: + return gqa_qkv_merge_func( + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + ) + + +def split_or_merge_func_v1( + is_split, + tensor_parallel_degree, + tensor_parallel_rank, + num_attention_heads=None, + num_key_value_heads=None, + head_dim=None, +): + """ + split_or_merge_func_v1 + """ + + def fn(x, **kwargs): + """func""" + is_gqa = kwargs.pop("is_gqa", False) + if is_gqa: + func = split_or_merge_qkv_func( + is_split=is_split, + tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_rank=tensor_parallel_rank, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + ) + is_column = kwargs.pop("is_column", True) + return func(x, is_column=is_column) + else: + func = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_rank=tensor_parallel_rank, + num_attention_heads=num_attention_heads, + ) + is_column = kwargs.pop("is_column", True) + is_naive_2fuse = kwargs.pop("is_naive_2fuse", False) + return func(x, is_column=is_column, is_naive_2fuse=is_naive_2fuse) + + return fn diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py index 63bca4b30..350f10651 100644 --- a/fastdeploy/model_executor/models/utils.py +++ b/fastdeploy/model_executor/models/utils.py @@ -16,6 +16,7 @@ from __future__ import annotations +import enum import hashlib import json import os @@ -23,29 +24,47 @@ import random import re import struct from functools import partial +from typing import NamedTuple, Optional import numpy as np import paddle import paddle.distributed as dist from paddle.common_ops_import import convert_dtype from paddle.distributed import fleet -from paddleformers.transformers.model_utils import (_add_variant, - load_tp_checkpoint) +from paddleformers.transformers.model_utils import _add_variant from paddleformers.transformers.utils import paddleformers_load from paddleformers.utils.env import (PADDLE_WEIGHTS_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) from paddleformers.utils.log import logger -from safetensors import safe_open from tqdm import tqdm -from fastdeploy.config import ModelConfig - MAX_BSZ = 
512 MAX_DRAFT_TOKENS = 6 +class LayerIdPlaceholder(str, enum.Enum): + """LayerIdPlaceholder""" + LAYER_ID = "layer_id" + FFN_LAYER_ID = "ffn_layer_id" + MOE_LAYER_ID = "moe_layer_id" + EXPERT_ID = "export_id" + + +class WeightMeta(NamedTuple): + """ + #tensor split parameters + + # weight_name: weight name + # is_column: whether to split by columns + # extra: optional flags like "is_naive_2fuse", "is_gqa", "is_naive_3fuse" + """ + weight_name: str + is_column: bool + extra: Optional[str] = None + + class UniqueIDGenerator: """ The generator for the export model id @@ -433,223 +452,6 @@ def calculate_effective_tokens(training_args, train_dataset, max_seq_len): return total_effective_tokens, total_tokens -def load_ep_checkpoint(model_path: str, - config: ModelConfig, - return_numpy: bool = False, - return_key_name: bool = True): - """ - load ep checkpoint - """ - # return_numpy=True cpu - # return_numpy=False gpu - with open(os.path.join(model_path, "model.safetensors.index.json"), - "r") as f: - weight_list = json.load(f)["weight_map"] - filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k} - num_local_ffn_keys = [] - - for i in range(config.moe_layer_start_index, config.num_layers): - for j in range( - config.num_experts_start_offset, - config.num_experts_start_offset + config.num_experts_per_rank, - ): - ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight" - ffn2_key = ( - f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight") - - ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight" - ffn2_quant_key = ( - f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight") - - ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale" - ffn2_scale_key = ( - f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale") - num_local_ffn_keys.append(ffn1_key) - num_local_ffn_keys.append(ffn2_key) - num_local_ffn_keys.append(ffn1_quant_key) - num_local_ffn_keys.append(ffn2_quant_key) - num_local_ffn_keys.append(ffn1_scale_key) - num_local_ffn_keys.append(ffn2_scale_key) - - for k in num_local_ffn_keys: - if k in weight_list: - filtered_map[k] = weight_list[k] - - state_dict = {} - # Get all safetensor file paths that need to be opened - safetensor_paths = set(filtered_map.values()) - - # Open each safetensor file sequentially with progress bar - for safetensor_path in tqdm(safetensor_paths, - desc="Loading safetensor files", - unit="file"): - with safe_open(os.path.join(model_path, safetensor_path), - framework="np", - device="cpu") as f: - # Check if this file contains keys from filtered_map - for k in filtered_map: - if filtered_map[k] == safetensor_path and k in f.keys(): - weight = f.get_tensor(k) - if not return_numpy: - weight = paddle.Tensor(weight, zero_copy=True) - weight = weight._copy_to( - paddle.framework._current_expected_place(), False) - state_dict[k] = weight - return state_dict - - -def get_safetensor_file(model_path): - """ - get_safetensor_file - """ - with open(os.path.join(model_path, "model.safetensors.index.json"), - "r") as f: - weight_map = json.load(f)["weight_map"] - weight_files_in_index = set() - for weight_name in weight_map: - weight_files_in_index.add( - os.path.join(model_path, weight_map[weight_name])) - key_name_list = list(set(weight_map.keys())) - safetensor_list = list(weight_files_in_index) - safetensor_list.sort() - return key_name_list, safetensor_list - - -def safetensors_weights_iterator(safe_tensor_list: list[str], ): - """ - safetensors_weights_iterator - """ - for st_file in 
tqdm( - safe_tensor_list, - desc="Loading safetensors checkpoint shards", - ): - with safe_open(st_file, framework="np") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param - - -def fastsafetensors_weights_iterator(safetensor_list: list[str]): - """ - fastsafetensors_weights_iterator - """ - from fastsafetensors import SafeTensorsFileLoader, SingleGroup - world_size = dist.get_world_size() - if world_size > 1: - dist.init_parallel_env() - pg = dist.get_group() - device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu" - else: - pg = SingleGroup() - device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda( - ) else "cpu" - - safetensor_files_sub_lists = [ - safetensor_list[i:i + world_size] - for i in range(0, len(safetensor_list), world_size) - ] - for st_file in tqdm( - safetensor_files_sub_lists, - desc="Loading fastsafetensors checkpoint shards", - ): - loader = SafeTensorsFileLoader(pg, - device, - nogds=True, - debug_log=False, - framework="paddle") - rank_file_map = {i: [f] for i, f in enumerate(st_file)} - loader.add_filenames(rank_file_map) - try: - fb = loader.copy_files_to_device() - try: - keys = list(fb.key_to_rank_lidx.keys()) - for k in keys: - t = fb.get_tensor(k) - yield k, t - finally: - fb.close() - finally: - loader.close() - - -def get_state_dict(model_path, config, use_fastsafetensor=False): - """ - get_state_dict - """ - state_dict = {} - _, safetensor_list = get_safetensor_file( - os.path.join(model_path, f"rank{config.tensor_parallel_rank}")) - if use_fastsafetensor: - weights_iterator = fastsafetensors_weights_iterator(safetensor_list) - else: - weights_iterator = safetensors_weights_iterator(safetensor_list) - - for name, weight in weights_iterator: - state_dict[name] = weight - return state_dict - - -def apply_quant(name_action_quant_mappings, key, tensor, state_dict): - """ - apply_quant - """ - if key in name_action_quant_mappings: - action = name_action_quant_mappings.pop(key) - quant_weight_tensor, weight_scale_tensor = action(tensor) - if quant_weight_tensor is not None and weight_scale_tensor is not None: - state_dict[key + ".quant_weight"] = quant_weight_tensor - state_dict[key + ".weight_scale"] = weight_scale_tensor - else: - state_dict[key] = quant_weight_tensor - else: - state_dict[key] = tensor - - -def load_checkpoint(model_path, cls, config, return_numpy=True, load_gpu=True): - """ - load checkpoint - """ - if getattr(config, "parallel_config", None) is not None: - use_ep = getattr(config.parallel_config, "use_ep", False) - tensor_parallel_degree = config.parallel_config.tensor_parallel_degree - else: - use_ep = getattr(config, "use_ep", False) - tensor_parallel_degree = config.tensor_parallel_degree - - if getattr(config, "model_config", None) is not None: - model_config = config.model_config - else: - model_config = config - - if use_ep: - state_dict = load_ep_checkpoint(model_path, - config, - return_numpy=True, - return_key_name=True) - else: - rank_dirs = [ - f for f in os.listdir(model_path) if f.startswith("rank") - and os.path.isdir(os.path.join(model_path, f)) - ] - if len(rank_dirs) > 1: - if tensor_parallel_degree != len(rank_dirs): - raise ValueError( - f"Your model only supports loading with tp{len(rank_dirs)}" - ) - state_dict = get_state_dict(model_path, model_config) - else: - state_dict = load_tp_checkpoint(model_path, - cls, - model_config, - return_numpy=return_numpy) - import re - for k, v in state_dict.items(): - match = re.search(r'layers\.(\d+)', k) - if match and 
int(match.group(1)) > 0: - continue - return state_dict - - def parser_quant_type(quant_type): """ Parse the quantization type string and return the corresponding quantization types for weights, diff --git a/fastdeploy/model_executor/ops/triton_ops/triton_utils_v2.py b/fastdeploy/model_executor/ops/triton_ops/triton_utils_v2.py new file mode 100644 index 000000000..6681d752f --- /dev/null +++ b/fastdeploy/model_executor/ops/triton_ops/triton_utils_v2.py @@ -0,0 +1,354 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import importlib +import inspect +import os +import re +import sys + +import paddle +import triton + +from .triton_utils import (SubstituteTemplate, build_package, compile_file, + extract_triton_kernel, find_so_path, + get_pointer_hint, link_file, multi_process_do, + python_path, rename_c_to_cu) + + +def get_value_hint(x): + """ + Get the value hint from input list. + """ + hint = "" + for ele in x: + if isinstance(ele, int): + hint += "i64," + continue + if ele % 16 == 0 and ele > 0: + hint += "i64:16," + elif ele == 1: + hint += "i64:1," + else: + hint += "i64," + if isinstance(ele, float): + hint += "fp32," + return hint + + +common_template = (""" +#include "${op_name}_kernel.h" +#include "paddle/extension.h" + +void ${op_name}_func(${tensor_and_attr}) { + auto run_stream = a_ptr->stream(); + auto res_flag = ${op_name}_kernel(run_stream, ${triton_kernel_args}, 0); + if (res_flag == CUDA_ERROR_INVALID_VALUE) { + PD_THROW("${op_name}_kernel failed"); + } +} + +PYBIND11_MODULE(${op_name}_package, m) { + + m.def("${op_name}_func", ${op_name}_func, "get expert token num"); +} + +""") + + +class KernelInterface: + """ + triton kernel interface. + """ + + def __init__( + self, + func, + other_config, + key_args=["1"], + ): + """ + triton kernel interface. + """ + self.func = func + self.key_args = key_args + + signature = inspect.signature(func) + self.arg_names = [v.name for v in signature.parameters.values()] + for ele in self.arg_names: + assert self.arg_names.count(ele) == 1 + # arg_defaults = [v.default for v in signature.parameters.values()] + + # self.annotations = { + # name: ty for name, ty in func.__annotations__.items() + # } + self.annotations = dict(func.__annotations__) + + self.constexprs = [ + self.arg_names.index(name) for name in self.arg_names + if self.annotations.get(name) == triton.language.core.constexpr + ] + + self.arg_exclude_constexpr = [ + self.arg_names[i] for i in range(len(self.arg_names)) + if i not in self.constexprs + ] + + import textwrap + + py_script = textwrap.dedent(inspect.getsource(func)) + + pat = r"def\s" + func.__name__ + func_begin = re.findall(pat, py_script) + assert len(func_begin) == 1 + func_begin = func_begin[0] + py_script = py_script[py_script.find(func_begin):] + + self.func_map = {} + + def decorator(*args, **kwargs): + """ + decorator for triton kernels. 
+ Args: + *args: positional arguments + **kwargs: keyword arguments + """ + op_name = "haha" + str(kwargs["N"]) + if op_name in self.func_map.keys(): + return self.func_map[op_name](*args) + + all_input = [] + + for i in range(len(args)): + all_input.append(args[i]) + + position_arguments_num = len(all_input) + for i in range(position_arguments_num, len(self.arg_names)): + if self.arg_names[i] in kwargs.keys(): + all_input.append(kwargs[self.arg_names[i]]) + else: + # means this input is not specified, it muse be a tl.constexpr. + assert i in self.constexprs + all_input.append(None) + + dtypes = [] + x_list = [] + const_args = [self.arg_names[i] for i in self.constexprs] + + decalare_arg_exclude_constexpr = list(self.arg_exclude_constexpr) + passed_arg_exclude_constexpr = list(self.arg_exclude_constexpr) + + const_hint_dict = {} + for i in range(len(all_input)): + ele = all_input[i] + + if type(ele) in [ + paddle.Tensor, paddle.base.framework.EagerParamBase, + paddle.base.framework.Parameter, + paddle.base.framework.Variable, + paddle.base.libpaddle.pir.Value, + type(None) + ]: + if ele is not None: + dtypes.append(ele.dtype) + passed_arg_exclude_constexpr[ + i] = f"(CUdeviceptr)({passed_arg_exclude_constexpr[i]}->data())" + else: + dtypes.append(paddle.int8) + passed_arg_exclude_constexpr[ + i] = "(CUdeviceptr)(nullptr)" + decalare_arg_exclude_constexpr[ + i] = "const paddle::optional&" + decalare_arg_exclude_constexpr[ + i] + elif i in self.constexprs: + if isinstance(ele, bool): + const_hint_dict[self.arg_names[i]] = (int)(ele) + elif isinstance(ele, int): + if ele < 0: + const_hint_dict[self.arg_names[i]] = 0 + else: + const_hint_dict[self.arg_names[i]] = ele + else: + assert False + else: + x_list.append(ele) + if isinstance(ele, int): + decalare_arg_exclude_constexpr[ + i] = "const int64_t " + decalare_arg_exclude_constexpr[ + i] + elif isinstance(ele, float): + decalare_arg_exclude_constexpr[ + i] = "const float " + decalare_arg_exclude_constexpr[ + i] + else: + assert False + + python_package_name = f"{op_name}_package" + tp_rank = paddle.distributed.get_rank() + + generated_dir = os.getenv("TRITON_KERNEL_CACHE_DIR", + f"/tmp/triton_cache/rank{tp_rank}") + print("the kernel cache dir is:", generated_dir) + generated_dir = f"{generated_dir}/{op_name}" + os.makedirs(generated_dir, exist_ok=True) + + py_script_file = f"{generated_dir}/triton_kernels.py" + extract_triton_kernel(func, py_script_file) + + address_hint = get_pointer_hint(dtypes) + value_hint = get_value_hint(x_list) + const_args = [f"{{{ele}}}" for ele in const_args] + const_args = ",".join(const_args) + + lanuch_grid = list(self.grid) + for i in range(len(lanuch_grid)): + ele = lanuch_grid[i] + if isinstance(ele, str): + keys = list(const_hint_dict.keys()) + keys.sort(key=len, reverse=True) + for key in keys: + if key in ele: + ele = ele.replace(key, f"{const_hint_dict[key]}") + else: + ele = str(ele) + lanuch_grid[i] = ele + + if len(lanuch_grid) < 3: + lanuch_grid += ["1"] * (3 - len(lanuch_grid)) + lanuch_grid = ",".join(lanuch_grid) + + op_dict = {"op_name": op_name} + op_dict["triton_kernel_args"] = ",".join( + passed_arg_exclude_constexpr) + op_dict["tensor_and_attr"] = ",".join( + decalare_arg_exclude_constexpr) + + paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu" + so_path = find_so_path(generated_dir, python_package_name) + + if so_path is None: + print("== we do not find so_path, we need to compile it") + with open(paddle_custom_op_file_path, "w") as f: + f.write(SubstituteTemplate( + 
common_template, + op_dict, + )) + f.close() + + # ahead of time compile command. + aot_template = ( + f"""{python_path} {compile_file} {py_script_file} """ + + f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """ + + f"""--out-name {op_name}_kernel """ + + """ -w {num_warps} -ns {num_stages} """ + + f""" -s"{address_hint} {value_hint} {const_args}" """ + + f""" -g "{lanuch_grid}" """) + + all_tune_config = [const_hint_dict] + # reset const_hint_dict as empty. + const_hint_dict = {} + codegen_commands = [] + for config in all_tune_config: + for key in const_hint_dict.keys(): + if const_hint_dict[key] is not None: + if key not in config.keys(): + config[key] = const_hint_dict[key] + else: + if config[key] == const_hint_dict[key]: + pass + else: + message = ( + f"you specify {key} both in arguments and config, " + "and they are not same, this is wrong." + ) + raise ValueError(message) + else: + assert key in config.keys( + ), f"you must specify {key} in your config." + if "num_warps" not in config.keys(): + config["num_warps"] = 4 + if "num_stages" not in config.keys(): + config["num_stages"] = 4 + + for key in config: + assert config[ + key] is not None, f"{key} must be specified." + codegen_command = aot_template.format(**config, ) + print(codegen_command) + codegen_commands.append(codegen_command) + multi_process_do(codegen_commands) + + link_command = ( + f"{python_path} {link_file} " + f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel") + re = os.system(link_command) + assert re == 0 + + # rename the .c file to .cu + rename_c_to_cu(generated_dir) + # build the package to so, not install + build_package(generated_dir, python_package_name) + + # so_path have be found! + so_path = find_so_path(generated_dir, python_package_name) + print("== we find so_path: ", so_path) + assert so_path is not None + dir_path = os.path.dirname(so_path) + sys.path.append(dir_path) + lib = importlib.import_module(python_package_name) + pybind_func = getattr(lib, f"{op_name}_func") + self.func_map[op_name] = pybind_func + + # run this op! + self.func_map[op_name](*args) + + self.decorator = decorator + + def __getitem__(self, op_name_and_grid): + """ + override the operator [], which will call the decorator function. + Args: + op_name_and_grid: the name of the operator and the grid size. + Returns: + the decorator function. + """ + self.grid = (( + "((max_possible_num_post_padded + BLOCK_SIZE_M -1)/ BLOCK_SIZE_M) * ((N + BLOCK_SIZE_N-1) / BLOCK_SIZE_N)" + ), ) + + return self.decorator + + +def paddle_use_triton_v2(other_config={}, key=[]): + """ + The decorator function that wraps the original function. + Args: + func: the original function. + Returns: + the wrapped function. + """ + + def decorator(func): + """ + The decorator function that wraps the original function. + Args: + func: the original function. + Returns: + the wrapped function. 
+ """ + return KernelInterface(func, other_config, key) + + return decorator diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 347a62d84..8f3601ff6 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -17,6 +17,7 @@ from typing import Dict, Optional import paddle +from fastdeploy import envs from fastdeploy.engine.config import SpeculativeConfig from fastdeploy.model_executor.ops.gpu import ( get_padding_offset, save_output, set_stop_value_multi_ends, @@ -24,10 +25,11 @@ from fastdeploy.model_executor.ops.gpu import ( speculate_get_padding_offset, speculate_get_seq_lens_output, speculate_save_output, speculate_set_value_by_flags_and_idx, speculate_step_paddle, speculate_step_system_cache, speculate_update_v3, - step_paddle, step_system_cache, update_inputs) + step_paddle, step_system_cache, update_inputs, step_reschedule) from fastdeploy.platforms import current_platform from fastdeploy.worker.output import ModelOutputData +DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1") def pre_process( max_len: int, @@ -214,6 +216,8 @@ def step_cuda( """ TODO(gongshaotian): normalization name """ + + if speculative_config.method is not None: if enable_prefix_caching: speculate_step_system_cache( @@ -291,6 +295,33 @@ def step_cuda( share_inputs["input_ids"], share_inputs["pre_ids"], share_inputs["step_idx"], share_inputs["next_tokens"], share_inputs["first_token_ids"], block_size, enc_dec_block_num) + elif DISABLE_RECOVER: + step_reschedule( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) else: step_paddle( share_inputs["stop_flags"], diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 00f32c4dc..085c1ab5b 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -341,7 +341,6 @@ class TokenProcessor(object): result.prompt = task.prompt result.prompt_token_ids = task.prompt_token_ids if recovery_stop: - result.outputs.token_ids.append(task.eos_token_ids[0]) result.error_msg = "Recover is not supported, the result is incomplete!" 
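Stepping back to the step_cuda change above: the new step_reschedule branch is selected purely by the FD_DISABLED_RECOVER environment switch, and (per the token_processor change in this hunk) a recovery stop now sets only the error message instead of also appending an EOS token. A simplified, condensed sketch of the resulting dispatch order; the returned strings are the operator names from this diff, and the speculative / prefix-caching branches that already existed are compressed for brevity:

def choose_step_op(speculative: bool, enable_prefix_caching: bool, disable_recover: bool) -> str:
    """Condensed view of the branch order in step_cuda after this patch (sketch)."""
    if speculative:
        return "speculate_step_system_cache" if enable_prefix_caching else "speculate_step_paddle"
    if enable_prefix_caching:
        return "step_system_cache"
    if disable_recover:
        # FD_DISABLED_RECOVER == "1": preempted requests are rescheduled, never recovered.
        return "step_reschedule"
    return "step_paddle"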
llm_logger.info( f"Request: {task_id} finished, number of " diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index aa7a624cf..769410c29 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -15,11 +15,15 @@ platform interface file """ -import paddle import enum + +import paddle + + class _Backend(enum.Enum): NATIVE_ATTN = enum.auto() APPEND_ATTN = enum.auto() + MLA_ATTN = enum.auto() class Platform: @@ -71,8 +75,7 @@ class Platform: if self.supported_quantization and quant not in self.supported_quantization: raise ValueError( f"{quant} quantization is currently not supported in " - f"{self.device_name}." - ) + f"{self.device_name}.") @classmethod def available(self): diff --git a/fastdeploy/platforms/cuda.py b/fastdeploy/platforms/cuda.py index 91184aef0..f5b3082b5 100644 --- a/fastdeploy/platforms/cuda.py +++ b/fastdeploy/platforms/cuda.py @@ -46,7 +46,7 @@ class CUDAPlatform(Platform): return False @classmethod - def get_attention_backend_cls(cls, selected_backend): + def get_attention_backend_cls(cls, selected_backend: _Backend): """ get_attention_backend_cls """ @@ -60,5 +60,13 @@ class CUDAPlatform(Platform): return ( "fastdeploy.model_executor.layers.attention.AppendAttentionBackend" ) + elif selected_backend == _Backend.MLA_ATTN: + logger.info("Using MLA ATTN backend.") + return ( + "fastdeploy.model_executor.layers.attention.MLAAttentionBackend" + ) else: - logger.warning("Other backends are not supported for now.") + raise ValueError( + "Invalid attention backend you specified.\n" + "Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place." + ) diff --git a/fastdeploy/reasoning/abs_reasoning_parsers.py b/fastdeploy/reasoning/abs_reasoning_parsers.py index f971f8865..f989547d9 100644 --- a/fastdeploy/reasoning/abs_reasoning_parsers.py +++ b/fastdeploy/reasoning/abs_reasoning_parsers.py @@ -13,15 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -import os from abc import abstractmethod from collections.abc import Sequence from functools import cached_property from typing import Callable, Optional, Union from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) -from fastdeploy.utils import data_processor_logger + DeltaMessage) from fastdeploy.utils import is_list_of @@ -120,7 +118,8 @@ class ReasoningParserManager: reasoning_parsers: dict[str, type] = {} @classmethod - def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]: + def get_reasoning_parser(cls, + name: Optional[str]) -> type[ReasoningParser]: """ Get reasoning parser by name which is registered by `register_module`. @@ -185,4 +184,4 @@ class ReasoningParserManager: cls._register_module(module=module, module_name=name, force=force) return module - return _register \ No newline at end of file + return _register diff --git a/fastdeploy/rl/__init__.py b/fastdeploy/rl/__init__.py new file mode 100644 index 000000000..f8f042712 --- /dev/null +++ b/fastdeploy/rl/__init__.py @@ -0,0 +1,20 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os + +from fastdeploy.model_executor.models import auto_models_registry + +auto_models_registry(os.path.dirname(__file__), "fastdeploy.rl", suffix="RL") diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py new file mode 100644 index 000000000..ae613f55d --- /dev/null +++ b/fastdeploy/rl/dynamic_weight_manager.py @@ -0,0 +1,282 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os +import time +from multiprocessing.shared_memory import SharedMemory +from typing import Any, Dict, List + +import numpy as np +import paddle +from paddle import nn +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.load_weight_utils import \ + load_composite_checkpoint +from fastdeploy.model_executor.model_loader import MODEL_CLASSES + + +class DynamicWeightManager: + """Manages model weights loading, updating and shared state across processes.""" + + def __init__(self, fd_config: FDConfig, model: nn.Layer): + """Initialize with config and model instances.""" + self.fd_config = fd_config + self.load_config = fd_config.load_config + self.parallel_config = fd_config.parallel_config + self.state_dict: Dict[str, paddle.Tensor] = {} + self.rank = fd_config.parallel_config.tensor_parallel_rank + self.nranks = paddle.distributed.get_world_size() + self.meta_src_id = self._get_gpu_id() + self.first_load = True + self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}" + self.models: List[nn.Layer] = [model] + self._capture_model_state() + + if self.load_config.load_strategy != "meta": + self.update_parameters() + + logger.info( + f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, " + f" rank={self.rank}, ranks={self.nranks}, " + f" load ipc weight from {self.ipc_path}.") + + @paddle.no_grad() + def _capture_model_state(self): + """Capture and store initial model parameters state.""" + for model in self.models: + for name, param in model.state_dict().items(): + logger.debug( + f"Model param: {name}, shape={param.shape}, dtype={param.dtype}" + ) + self.state_dict[name] = param + + def add_model(self, model: nn.Layer): + """"add model""" + self.models.append(model) + self._capture_model_state() + + def update_parameters(self, pid: int = 0) -> None: + """Core method to update model parameters based on strategy.""" + start_time = time.perf_counter() + paddle.device.cuda.empty_cache() + + if not self.first_load: + paddle.distributed.restart_process_group() + + strategy_handlers = { + 
"ipc_snapshot": self._update_ipc_snapshot, + "ipc": self._update_ipc, + "ipc_no_reshard": self._update_ipc_no_reshard, + "normal": self.load_model, + } + + if handler := strategy_handlers.get(self.load_config.load_strategy): + handler() + else: + raise ValueError( + f"Unsupported strategy: {self.load_config.load_strategy}") + + logger.info( + f"Update parameters in {time.perf_counter()-start_time:.2f}s") + + self._finalize_update(pid) + + def _update_ipc_snapshot(self): + """Update using IPC snapshot strategy for elastic recovery.""" + model_path = os.path.join( + self.parallel_config.model_name_or_path, + f"model_state.tp0{self.meta_src_id}.pdparams") + + try: + ipc_state_dict = paddle.load(model_path) + except FileNotFoundError: + fallback_path = f"/shared_ipc_meta/model_state.tp0{self.meta_src_id}.pdparams" + ipc_state_dict = paddle.load(fallback_path) + + try: + self._update_model_from_state(ipc_state_dict, "snapshot") + except Exception: + self.models[0].set_state_dict(ipc_state_dict) + logger.warning( + "load model from no_reshard weight, maybe need more GPU memory" + ) + logger.info("IPC snapshot update parameters completed") + + def _update_ipc(self): + """Update using standard IPC strategy (requires Training Worker).""" + ipc_meta = paddle.load(self.ipc_path) + state_dict = self._convert_ipc_meta_to_tensor(ipc_meta) + self._update_model_from_state(state_dict, "raw") + logger.info("IPC update parameters completed") + + def _update_ipc_no_reshard(self): + """Update using no-reshard IPC strategy (faster but uses more memory).""" + ipc_meta = paddle.load(self.ipc_path) + state_dict = self._convert_ipc_meta_to_tensor(ipc_meta) + self.models[0].set_state_dict(state_dict) + logger.info("IPC no-reshard update parameters completed") + + def load_model(self) -> nn.Layer: + """Standard model loading without IPC.""" + architectures = self.fd_config.model_config.architectures[0] + model_class = MODEL_CLASSES[architectures] + state_dict = load_composite_checkpoint( + self.fd_config.parallel_config.model_name_or_path, + model_class, + self.fd_config.model_config, + return_numpy=True) + self.models[0].set_state_dict(state_dict) + logger.info("normal load update parameters completed") + + def clear_parameters(self, pid: int = 0) -> None: + """Clear all model parameters and free memory.""" + logger.info("start clear paramaters") + paddle.device.cuda.empty_cache() + for model in self.models: + for param in model.state_dict().values(): + param._clear_data() + + self._verify_parameters("clearance") + if self.nranks > 1: + paddle.distributed.barrier() + paddle.distributed.shutdown_process_group() + self._update_shared_status(pid, -2) + + def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], + src_type: str): + """Update model parameters from given state dictionary.""" + update_count = 0 + for name, new_param in state_dict.items(): + if name not in self.state_dict: + logger.debug(f"Ignoring unmatched {src_type} param: {name}") + continue + + target_param = self.state_dict[name] + self._validate_parameter_match(name, new_param, target_param) + new_param._share_buffer_to(target_param) + update_count += 1 + logger.info( + f"🆗 Updated {update_count}/{len(state_dict)} parameters from {src_type} source" + ) + + def _validate_parameter_match(self, name: str, src: paddle.Tensor, + dst: paddle.Tensor): + """验证参数一致性""" + if src.dtype != dst.dtype: + raise TypeError( + f"Type mismatch for {name}: {src.dtype} vs {dst.dtype}") + if src.shape != dst.shape: + raise ValueError( + f"Shape mismatch for 
{name}: {src.shape} vs {dst.shape}") + + def _finalize_update(self, pid: int): + """Finalize update process with verification.""" + self._verify_parameters("update") + if self.nranks > 1: + paddle.distributed.barrier() + if not self.first_load: + self._update_shared_status(pid, 0) + self.first_load = False + + def _get_gpu_id(self) -> int: + """Get current GPU device ID.""" + visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0").split(",") + return int(visible_devices[int(os.getenv("FLAGS_selected_gpus", "0"))]) + + def _verify_parameters(self, operation: str): + """Verify parameters are in expected state after operation.""" + expected_initialized = (operation == "update") + all_valid = True + for name, param in self.state_dict.items(): + is_initialized = param._is_initialized() + if is_initialized != expected_initialized: + logger.error( + f"Verification failed after {operation}: " + f"Param {name} initialized={is_initialized} (expected {expected_initialized})" + ) + all_valid = False + + if all_valid: + logger.info(f"💡 Model Parameter {operation} verified successfully") + else: + raise RuntimeError( + f"❌ Model Parameter {operation} verification failed") + + @staticmethod + def _convert_ipc_meta_to_tensor( + ipc_meta: Dict[str, Any]) -> Dict[str, paddle.Tensor]: + """Convert IPC metadata to tensor dictionary.""" + converted = {} + for name, meta in ipc_meta.items(): + meta[0] = meta[0].encode("latin-1") + meta[6] = int(os.getenv("FLAGS_selected_gpus", "0")) + tensor = paddle.base.core.LoDTensor._new_shared_cuda(tuple(meta)) + converted[name] = paddle.to_tensor(tensor) + return converted + + def _log_memory(self, context: str): + """Log current GPU memory usage.""" + max_alloc = paddle.device.cuda.max_memory_allocated() / (1024**3) + max_reserved = paddle.device.cuda.max_memory_reserved() / (1024**3) + curr_alloc = paddle.device.cuda.memory_allocated() / (1024**3) + curr_reserved = paddle.device.cuda.memory_reserved() / (1024**3) + + logger.warning(f"GPU memory usage {context}:" + f"max_allocated: {max_alloc:.2f}GB\n" + f"max_reserved: {max_reserved:.2f}GB\n" + f"current_allocated: {curr_alloc:.2f}GB\n" + f"current_reserved: {curr_reserved:.2f}GB") + + def _update_shared_status(self, pid: int, status: int) -> None: + """Update shared memory status flag for inter-process communication.""" + array = np.zeros([1], dtype=np.int32) + shm = SharedMemory(create=False, + size=array.nbytes, + name=f"model_weights_status.{pid}") + value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf) + if self.rank == 0: + value[self.rank] = status + + @staticmethod + def check_model_weights_status(model_weights_status, model_runner, pid): + """ + check model weights status + """ + is_stop = 0 + while model_weights_status.value[0] != 0: + if model_weights_status.value[0] == 1: + logger.info( + "infer engine stopped! start to load new checkpoint...") + model_runner.update_parameters(pid) + elif model_weights_status.value[0] == -1: + logger.info( + "infer engine stopped! 
start to clear checkpoint...") + model_runner.clear_parameters(pid) + + while True: + if model_weights_status.value[0] == 0: + logger.info("finished loading new checkpoint") + break + elif is_stop == 1 or (model_weights_status.value[0] == -2 + and is_stop == 0): + if is_stop == 0: + logger.info("finished clearing checkpoint") + is_stop = 1 + time.sleep(0.001) + break + else: + time.sleep(0.001) diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py new file mode 100644 index 000000000..53e453274 --- /dev/null +++ b/fastdeploy/rl/rollout_model.py @@ -0,0 +1,215 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Dict + +import paddle +from paddle import nn +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.model_loader import ModelRegistry +from fastdeploy.model_executor.models.ernie4_5_moe import \ + Ernie4_5_MoeForCausalLM +from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel +from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel +from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel + +RL_MODEL_CLASSES = { + "Ernie4_5_MoeForCausalLMRL": Ernie4_5_MoeForCausalLM, + "Qwen2ForCausalLMRL": Qwen2PretrainedModel, + "Qwen3ForCausalLMRL": Qwen3PretrainedModel, + "Qwen3MoeForCausalLMRL": Qwen3MoePretrainedModel, +} + + +class RollOutModel(nn.Layer): + """Main model class for rollout operations, supports multimodal components for train.""" + + def __init__(self, fd_config: FDConfig): + """Initialize with FastDeploy configuration.""" + super(RollOutModel, self).__init__() + self.fd_config = fd_config + self._init_models() + + def _init_models(self): + """Initialize all model components including multimodal if needed.""" + self.is_vl = "VL" in self.fd_config.model_config.architectures[0] + self.rollout_model = self._load_primary_model() + self.rollout_models = [self.rollout_model] + + if self.is_vl: + self._init_multimodal_models() + self.rollout_models.extend( + [self.vision_model, self.resampler_model]) + + def _init_multimodal_models(self): + """Initialize vision and resampler components for multimodal models.""" + # TODO:(gaoziyuan) Implement actual initialization + self.vision_model = nn.Layer() + self.resampler_model = nn.Layer() + + def _load_primary_model(self): + """Load main model from loader based on config.""" + if "VL" in self.fd_config.model_config.architectures[0]: + logger.error("Loaded Vision Language model, not support now") + + context = paddle.LazyGuard() + architectures = f"{self.fd_config.model_config.architectures[0]}RL" + with context: + model_cls = ModelRegistry.get_class(architectures) + model = model_cls(self.fd_config) + + model.eval() + return model + + def get_name_mappings_to_training(self) -> Dict[str, str]: + """Get parameter name mappings between rollout and training models.""" + mappings = {} + for model in self.rollout_models: + 
mappings.update( + getattr(model, "get_name_mappings_to_training", lambda: {})()) + return mappings + + @paddle.no_grad() + def state_dict(self): + """state_dict""" + all_params = {} + for model in self.rollout_models: + for name, param in model.state_dict().items(): + logger.debug( + f"Model param: {name}, shape={param.shape}, dtype={param.dtype}" + ) + all_params[name] = param + return all_params + + +class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM): + """ + Ernie4_5_MoeForCausalLMRL + """ + + def __init__(self, fd_config: FDConfig): + """ + Args: + fd_config (FDConfig): Configurations for the LLM model. + """ + super(Ernie4_5_MoeForCausalLMRL, self).__init__(fd_config) + + @classmethod + def name(self): + """name""" + return "Ernie4_5_MoeForCausalLMRL" + + def get_name_mappings_to_training(self): + """Generate mapping between inference and training parameter for RL(donot delete!).""" + have_bias = self.fd_config.model_config.get("have_norm_bias", False) + # Prepare placeholders + place_holders = ["weight"] + (["bias"] if have_bias else []) + + # Initialize mapping dictionary + infer_to_train = {} + + # Static mappings (non-layer specific) + static_mappings = { + "model.embeddings.word_embeddings.weight": + "ernie.embed_tokens.weight", + "model.norm.ln_weight": "ernie.norm.weight", + "lm_head.out_linear.weight": "lm_head.weight" + } + if self.fd_config.model_config.get("weight_sharing", False): + # Support tie_word_embeddings + logger.debug("enable tie_word_embeddings") + static_mappings.pop("lm_head.out_linear.weight") + infer_to_train.update(static_mappings) + infer_base_name = "model.hidden_layers" + + # Helper function to add layer mappings + def _add_layer_mappings(layer_idx, is_moe_layer=False): + # Handle special case for layer 0's input layernorm + for ph in place_holders: + infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}" + train_key = f"ernie.layers.{layer_idx}.input_layernorm.{ph}" + infer_to_train[infer_key] = train_key + + # Common attention mappings + for ph in place_holders: + infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \ + f"ernie.layers.{layer_idx}.self_attn.qkv_proj.{ph}" + + infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \ + f"ernie.layers.{layer_idx}.self_attn.o_proj.{ph}" + + # Post-attention layernorm + for ph in place_holders: + infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \ + f"ernie.layers.{layer_idx}.post_attention_layernorm.{ph}" + + if not is_moe_layer: + # Dense FFN mappings + for ph in place_holders: + infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \ + f"ernie.layers.{layer_idx}.mlp.up_gate_proj.{ph}" + + infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \ + f"ernie.layers.{layer_idx}.mlp.down_proj.{ph}" + else: + # MoE specific mappings + infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \ + f"ernie.layers.{layer_idx}.mlp.gate.weight" + + if self.fd_config.moe_config.moe_use_aux_free: + infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \ + f"ernie.layers.{layer_idx}.mlp.moe_statics.e_score_correction_bias" + + # Support shared experts + if self.fd_config.model_config.get( + "moe_num_shared_experts") > 0: + infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \ + f"ernie.layers.{layer_idx}.mlp.shared_experts.up_gate_proj.weight" + 
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \ + f"ernie.layers.{layer_idx}.mlp.shared_experts.down_proj.weight" + + # MoE experts mappings + for expert_idx in range(self.fd_config.moe_config.num_experts): + for ph in place_holders: + # FFN1 (up_gate_proj) + ffn1_key = f"{infer_base_name}.{layer_idx}.mlp.fused_moe.moe_ffn1_weight" + if ffn1_key not in infer_to_train: + infer_to_train[ffn1_key] = [] + infer_to_train[ffn1_key].append( + f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}" + ) + + # FFN2 (down_proj) + ffn2_key = f"{infer_base_name}.{layer_idx}.mlp.fused_moe.moe_ffn2_weight" + if ffn2_key not in infer_to_train: + infer_to_train[ffn2_key] = [] + infer_to_train[ffn2_key].append( + f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}" + ) + + # Process non-MoE layers + for layer_idx in range( + self.fd_config.moe_config.moe_layer_start_index): + _add_layer_mappings(layer_idx, is_moe_layer=False) + + # Process MoE layers + for layer_idx in range(self.fd_config.moe_config.moe_layer_start_index, + self.fd_config.model_config.num_layers): + _add_layer_mappings(layer_idx, is_moe_layer=True) + + return infer_to_train diff --git a/fastdeploy/scheduler/splitwise_scheduler.py b/fastdeploy/scheduler/splitwise_scheduler.py index 50cd652b9..be4534974 100644 --- a/fastdeploy/scheduler/splitwise_scheduler.py +++ b/fastdeploy/scheduler/splitwise_scheduler.py @@ -260,12 +260,13 @@ class ResultReader(object): ResultReader use an async thread to continue get infer result from redis """ - def __init__(self, client, idx, batch=200, ttl=900): + def __init__(self, client, idx, batch=200, ttl=900, group=""): self.idx = idx self.batch = batch self.client = client self.data = deque() self.ttl = ttl + self.group = group self.reqs = dict() self.out_buffer = dict() @@ -380,15 +381,18 @@ class ResultReader(object): fetch infer results from redis for the give keys """ total = 0 + if self.group != "": + keys = [self.group] for key in keys: + #logger.info(f"Sync Results from Redis {key}") results = self.client.rpop(key, self.batch) if results is None or len(results) == 0: continue - #logger.info(f"Rpop {self.idx}: {len(results)}") + #logger.info(f"Rpop {key} {self.idx}: {len(results)}") total += len(results) for result in results: try: - #logger.info(f"Scheduler Get Results: {result}") + # logger.info(f"Scheduler Get Results: {result.request_id}") data = orjson.loads(result) result = RequestOutput.from_dict(data) self.data.appendleft(result) @@ -425,8 +429,9 @@ class APIScheduler(object): start backup threads """ for i in range(self.reader_parallel): + group = f"{self.nodeid}-{i}" reader = ResultReader(self.client, i, self.reader_batch_size, - self.ttl) + self.ttl, group) self.readers.append(reader) self.clear_expired_nodes_thread = threading.Thread( @@ -481,15 +486,16 @@ class APIScheduler(object): reader = self.readers[reader_idx] reader.add_req(req) + group = self.readers[reader_idx].group reader_idx = (reader_idx + 1) % len(self.readers) - self.schedule(req, pnodes, dnodes, mnodes) + self.schedule(req, pnodes, dnodes, mnodes, group) except IndexError: continue except Exception as e: logger.error(f"APIScheduler Schedule req error: {str(e)}") - def schedule(self, req, pnodes, dnodes, mnodes): + def schedule(self, req, pnodes, dnodes, mnodes, group=""): """ schedule an req to according redis node queue """ @@ -498,7 +504,9 @@ class APIScheduler(object): pnode = self.select_pd(req, pnodes, "prefill") if pnode.role == "mixed": 
req.disaggregate_info = None - req_str = orjson.dumps(req.to_dict()) + req_dict = req.to_dict() + req_dict["group"] = group + req_str = orjson.dumps(req_dict) pkey = f"ReqQ_{pnode.nodeid}" #logger.info(f"Schedule Req {req_str} to Mixed") self.client.lpush(pkey, req_str) @@ -518,7 +526,9 @@ class APIScheduler(object): disaggregated["transfer_protocol"] = transfer_protocol[0] req.disaggregate_info = disaggregated pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}" - req_str = orjson.dumps(req.to_dict()) + req_dict = req.to_dict() + req_dict["group"] = group + req_str = orjson.dumps(req_dict) #logger.info(f"Schedule Req {req_str}") self.client.lpush(dkey, req_str) self.client.lpush(pkey, req_str) @@ -634,7 +644,9 @@ class ResultWriter(object): size = len(self.data) if size == 0: self.cond.wait() + #qsize = size size = min(size, self.batch) + #logger.info(f"Writer {self.idx} Queue Size: {qsize}, Cur Size: {size}") groups = dict() for i in range(size): key, item = self.data.pop() @@ -749,12 +761,13 @@ class InferScheduler(object): for req_str in reqs: req = orjson.loads(req_str) + group = req.get("group", "") req = Request.from_dict(req) writer_idx = select_writer(req) logger.info( f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}" ) - req.request_id = f"{req.request_id}#{writer_idx}" + req.request_id = f"{req.request_id}#{writer_idx}#{group}" if self.role == "prefill" or self.role == "mixed": self.reqs_queue.append(req) self.node.add_req(req.request_id, @@ -813,10 +826,10 @@ class InferScheduler(object): req_ids.add(result.request_id) - req_id, idx = result.request_id.split("#") + req_id, idx, group = result.request_id.split("#") result.request_id = req_id - key = (req_id, int(idx)) + key = (req_id if group == "" else group, int(idx)) if key not in groups: groups[key] = list() diff --git a/fastdeploy/start_splitwise.sh b/fastdeploy/start_splitwise.sh deleted file mode 100644 index d6e521dcc..000000000 --- a/fastdeploy/start_splitwise.sh +++ /dev/null @@ -1,15 +0,0 @@ - -export FLAGS_use_pd_disaggregation=1 - - -export INFERENCE_MSG_QUEUE_ID=1 -export FD_LOG_DIR="log_decode" -CUDA_VISIBLE_DEVICES=4,5,6,7 python fastdeploy.entrypoints.openai.api_server.py --config test.yaml --port 9812 --max-num-seqs 256 --kv-cache-ratio 0.8 --splitwise-role "decode" --engine-worker-queue-port 6678 --innode-prefill-ports 6677 --cache-queue-port 55667 --enable-prefix-caching --enable-chunked-prefill & - - -export FD_LOG_DIR="log_prefill" -export INFERENCE_MSG_QUEUE_ID=3 -export FLAGS_fmt_write_cache_completed_signal=1 -export PREFILL_NODE_ONE_STEP_STOP=1 -CUDA_VISIBLE_DEVICES=0,1,2,3 python fastdeploy.entrypoints.openai.api_server.py --config test.yaml --port 9811 --cpu-offload-gb 5 --max-num-seqs 16 --kv-cache-ratio 0.9 --splitwise-role "prefill" --engine-worker-queue-port 6677 --enable-prefix-caching --cache-queue-port 55663 & - diff --git a/fastdeploy/test.yaml b/fastdeploy/test.yaml index bcf7ad20b..1738b37e2 100644 --- a/fastdeploy/test.yaml +++ b/fastdeploy/test.yaml @@ -1,4 +1,4 @@ -model: "baidu/ERNIE-45-300B-A47B-Paddle" +model: "baidu/paddle_internal/ERNIE-45-Turbo" max_model_len: 32768 max_num_seqs: 128 kv_cache_ratio: 0.5 diff --git a/fastdeploy/worker/eplb.py b/fastdeploy/worker/eplb.py index 4aca25e56..45ce85eac 100644 --- a/fastdeploy/worker/eplb.py +++ b/fastdeploy/worker/eplb.py @@ -84,9 +84,10 @@ def replicate_experts( return phy2log, rank, logcnt -def rebalance_experts_hierarchical(weight: np.ndarray, - num_physical_experts: int, num_groups: int, - num_nodes: int, 
num_gpus: int): +def rebalance_experts_hierarchical( + weight: np.ndarray, num_physical_experts: int, num_groups: int, + num_nodes: int, + num_gpus: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Parameters: weight: [num_moe_layers, num_logical_experts] diff --git a/fastdeploy/worker/experts_manager.py b/fastdeploy/worker/experts_manager.py index 79b8ba769..53bc0b725 100644 --- a/fastdeploy/worker/experts_manager.py +++ b/fastdeploy/worker/experts_manager.py @@ -1,7 +1,7 @@ """ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,9 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -""" -redundant expert manger -""" +"""redundant expert manger.""" +from typing import Optional, Tuple import numpy as np import paddle @@ -29,9 +28,9 @@ class RedundantExpertManger: RedundantExpertManger """ - def __init__(self, n_routed_experts, num_hidden_layers, - redundant_experts_num, ep_size): - + def __init__(self, n_routed_experts: int, num_hidden_layers: int, + redundant_experts_num: int, ep_size: int) -> None: + """Initialize a redundant expert manager""" self.num_expert = n_routed_experts self.redundant_experts_num = redundant_experts_num self.num_hidden_layers = num_hidden_layers @@ -94,7 +93,9 @@ class RedundantExpertManger: num_replicas {self.num_replicas} export_per_rank {self.export_per_rank}" ) - def get_ep_rank_to_expert_id_list_by_layer(self, layer_id): + def get_ep_rank_to_expert_id_list_by_layer( + self, layer_id: int + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ get_ep_rank_to_expert_id_list_by_layer """ @@ -103,7 +104,9 @@ class RedundantExpertManger: self.model_expert_in_rank_num_list[layer_id], \ self.model_tokens_per_expert_stats_list[layer_id] - def get_ep_rank_to_expert_id_list(self, layer_id): + def get_ep_rank_to_expert_id_list( + self, layer_id: int + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ get_ep_rank_to_expert_id_list """ @@ -112,9 +115,12 @@ class RedundantExpertManger: self.model_expert_in_rank_num_list[layer_id], \ self.model_tokens_per_expert_stats_list[layer_id] - def get_expert_tokens_stats(self, - verbose: bool = False, - clear_stat: bool = False): + def get_expert_tokens_stats( + self, + verbose: bool = False, + clear_stat: bool = False + ) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray], + Optional[np.ndarray]]: """ get_per_expert_tokens_stats """ @@ -130,7 +136,7 @@ class RedundantExpertManger: if clear_stat: self.model_tokens_per_expert_stats_list.zero_() - def get_expert_id_to_ep_rank_array(self): + def get_expert_id_to_ep_rank_array(self) -> np.ndarray: """ get_expert_id_to_ep_rank_array """ @@ -140,7 +146,7 @@ class RedundantExpertManger: rank_expert_list: np.ndarray, logical_to_physical_map: np.ndarray, expert_count: np.ndarray, - clear_stat: bool = True): + clear_stat: bool = True) -> None: """ update_expert_rank_table """ diff --git a/fastdeploy/worker/forward_meta.py b/fastdeploy/worker/forward_meta.py index 18149c969..a1007f4e1 100644 --- a/fastdeploy/worker/forward_meta.py +++ b/fastdeploy/worker/forward_meta.py @@ -330,6 +330,8 @@ class ForwardMeta(): decoder_batch_ids: 
Optional[paddle.Tensor] = None # for attention backend decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None + # is_decode_batch or not + is_decode_batch: bool = False @classmethod def init_forward_meta(cls, share_inputs: Dict, @@ -356,6 +358,11 @@ class ForwardMeta(): "decoder_tile_ids_per_batch", None), ) return ret + + def clear_caches(self): + """safe clear caches""" + if self.caches: + del self.caches @dataclass diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 7fb7ab5c7..b211aa9a8 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -20,6 +20,7 @@ from typing import List, Optional import numpy as np import paddle import paddle.nn as nn +from paddleformers.utils.log import logger from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request @@ -41,13 +42,10 @@ from fastdeploy.model_executor.pre_and_post_process import (post_process, rebuild_padding, step_cuda) from fastdeploy.spec_decode import MTPProposer, NgramProposer -from fastdeploy.utils import get_logger from fastdeploy.worker.forward_meta import ForwardMeta from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput -logger = get_logger("gpu_model_runner", "gpu_model_runner.log") - class GPUModelRunner(ModelRunnerBase): """ """ @@ -593,6 +591,10 @@ class GPUModelRunner(ModelRunnerBase): time_before_load = time.perf_counter() # 1. Load original model self.model = get_model_from_loader(fd_config=self.fd_config) + # 1.1 Load RL dynamic model + if self.fd_config.load_config.dynamic_load_weight: + from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager + self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model) # 2. 
Load lora model @@ -620,6 +622,25 @@ class GPUModelRunner(ModelRunnerBase): # Initialzie attention meta data for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) + + def clear_cache(self): + """Clear cached data from shared inputs and forward metadata.""" + self.share_inputs.pop("caches", None) + if self.forward_meta is not None: + self.forward_meta.clear_caches() + + def clear_parameters(self, pid): + """"dynamic model loader use to clear parameters use for RL""" + self.dynamic_weight_manager.clear_parameters(pid) + self.clear_cache() + paddle.device.cuda.empty_cache() + self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory") + + def update_parameters(self, pid): + """"dynamic model loader use to update parameters use for RL""" + self.dynamic_weight_manager.update_parameters(pid) + self.initialize_kv_cache() + self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") def initialize_kv_cache(self) -> None: """ @@ -691,15 +712,14 @@ class GPUModelRunner(ModelRunnerBase): head_dim = self.model_config.head_dim # Get the attention backend - attn_cls = get_attention_backend( - self.parallel_config.attention_backend) + attn_cls = get_attention_backend() attn_backend = attn_cls(self.fd_config, kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim) if attn_backend is None: raise NotImplementedError( - f"{ self.parallel_config.attention_backend} attention backend is not support by GPUModelRunner" + "Attention backend which you chose is not support by GPUModelRunner" ) self.attn_backends.append(attn_backend) @@ -735,6 +755,7 @@ class GPUModelRunner(ModelRunnerBase): is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0) self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing + self.forward_meta.is_decode_batch = is_decode_batch model_output = self.model( ids_remove_padding=self.share_inputs["ids_remove_padding"], forward_meta=self.forward_meta) @@ -967,6 +988,7 @@ class GPUModelRunner(ModelRunnerBase): is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0) self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch + self.forward_meta.is_decode_batch = is_decode_batch model_output = self.model( ids_remove_padding=self.share_inputs["ids_remove_padding"], forward_meta=self.forward_meta) @@ -1124,9 +1146,7 @@ class GPUModelRunner(ModelRunnerBase): batch_size=min(self.parallel_config.max_num_seqs, 3)) # 3. 
gc - del self.share_inputs["caches"] - if self.forward_meta is not None: - del self.forward_meta.caches + self.clear_cache() if self.speculative_method in ["mtp"]: self.proposer.clear_dummy_input() diff --git a/fastdeploy/worker/vl_gpu_model_runner.py b/fastdeploy/worker/vl_gpu_model_runner.py index 302676140..f48cefe8f 100644 --- a/fastdeploy/worker/vl_gpu_model_runner.py +++ b/fastdeploy/worker/vl_gpu_model_runner.py @@ -16,10 +16,12 @@ import json import os import random +import argparse import numpy as np import paddle import paddle.distributed.fleet as fleet +from paddleformers.transformers.model_utils import load_tp_checkpoint from safetensors import safe_open from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer @@ -38,11 +40,13 @@ from fastdeploy.model_executor.models.ernie4_5_vl.dfnrope.modeling import \ DFNRopeVisionTransformerPretrainedModel from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ( ScatterOp, VariableResolutionResamplerModel) -from fastdeploy.model_executor.models.utils import load_checkpoint from fastdeploy.platforms import current_platform from fastdeploy.worker.forward_meta import ForwardMeta from fastdeploy.worker.utils import check_safetensors_model from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase +from fastdeploy.config import (DeviceConfig, FDConfig, KVCacheConfig, + LoadConfig, ModelConfig, MoEConfig, + MoEPhase, ParallelConfig, SpeculativeConfig) if current_platform.is_cuda() and current_platform.available(): from fastdeploy.model_executor.layers.utils import ( @@ -55,8 +59,20 @@ from fastdeploy.model_executor.ops.gpu import (save_output, class GPUVLModelRunner(VLModelRunnerBase): + """ + The GPUVLModelRunner class for vision-language tasks on GPU. + """ - def __init__(self, config, args, nranks, rank): + def __init__( + self, + config: ModelConfig, + args: argparse.Namespace, + nranks: int, + rank: int, + ) -> None: + """ + GPUVLModelRunner init + """ self.nranks = nranks self.rank = rank @@ -104,14 +120,11 @@ class GPUVLModelRunner(VLModelRunnerBase): self.sampler = Sampler() def _reset_paddle_env(self): - #FLAGS_gqa_use_tensorcore - #FLAGS_ffn2_use_hardamard - # gqa .etc paddle Flags set pass - def update_chunked_prefill(self, tasks): + def update_chunked_prefill(self, tasks: list[any]) -> None: """ - 更新chunked prefill相关参数 + update chunked prefill """ if not self.args.enable_chunked_prefill: return @@ -135,7 +148,7 @@ class GPUVLModelRunner(VLModelRunnerBase): "image_features"] = self.extract_vision_features( inputs) else: - # 兼容没有图片和视频的情况 + # Compatible with the situation that lacks images and videos self.share_inputs["image_features"] = None token_chunk_size = inputs["input_ids"].shape[1] @@ -152,7 +165,14 @@ class GPUVLModelRunner(VLModelRunnerBase): task.start_idx += token_chunk_size task.chunk_idx += 1 - def _load_model(self, model_name, dynamic_load_weight): + def _load_model( + self, + model_name: str, + dynamic_load_weight: int = 0, + ) -> None: + """ + Load the model from the given model name. 
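The gpu_model_runner.py hunks above add clear_parameters/update_parameters so an RL trainer can swap weights in place: weights and the KV cache are released first, then refreshed weights are loaded and the cache is rebuilt. A minimal sketch of that call order, with stand-in stubs for DynamicWeightManager and the cache (the stub names are illustrative, not this project's API):

class _StubWeightManager:
    def clear_parameters(self, pid: int) -> None:
        print(f"[{pid}] released model weights")

    def update_parameters(self, pid: int) -> None:
        print(f"[{pid}] loaded refreshed model weights")


class _StubRunner:
    def __init__(self) -> None:
        self.dynamic_weight_manager = _StubWeightManager()
        self.share_inputs = {"caches": ["kv_block_0", "kv_block_1"]}

    def clear_cache(self) -> None:
        # Mirrors GPUModelRunner.clear_cache(): drop KV cache references so the
        # allocator can actually reclaim the memory.
        self.share_inputs.pop("caches", None)

    def initialize_kv_cache(self) -> None:
        self.share_inputs["caches"] = ["kv_block_0", "kv_block_1"]

    def clear_parameters(self, pid: int) -> None:
        # 1. free weights, 2. free the KV cache (the patch then calls
        # paddle.device.cuda.empty_cache() to return memory to the driver).
        self.dynamic_weight_manager.clear_parameters(pid)
        self.clear_cache()

    def update_parameters(self, pid: int) -> None:
        # 1. load refreshed weights, 2. rebuild the KV cache for serving.
        self.dynamic_weight_manager.update_parameters(pid)
        self.initialize_kv_cache()


if __name__ == "__main__":
    runner = _StubRunner()
    runner.clear_parameters(pid=0)   # before the RL trainer publishes new weights
    runner.update_parameters(pid=0)  # after the new weights are available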
+ """ vocab_file_names = [ "tokenizer.model", "spm.model", "ernie_token_100k.model" @@ -261,7 +281,7 @@ class GPUVLModelRunner(VLModelRunnerBase): fd_config.parallel_config.max_model_len = fd_config.model_config.max_seq_len self.fd_config = fd_config - attn_backend_cls = get_attention_backend(self.args.attention_backend) + attn_backend_cls = get_attention_backend() num_heads = self.fd_config.model_config.num_attention_heads // \ self.fd_config.parallel_config.tensor_parallel_degree self.fd_config.model_config.kv_num_heads = int( @@ -275,7 +295,10 @@ class GPUVLModelRunner(VLModelRunnerBase): head_dim=head_dim) self._init_kvcache() - def init_extra_input(self, config, args): + def init_extra_input(self, config: ModelConfig, args: argparse.Namespace) -> None: + """ + Initialize extra input tensors. + """ head_dim = self.model_cfg.head_dim self.share_inputs.update({ "rope_emb": @@ -287,29 +310,31 @@ class GPUVLModelRunner(VLModelRunnerBase): }) self.share_inputs.update({"image_features": None}) self.share_inputs.update({ - "need_think_end": paddle.full(shape=[ - args.max_num_seqs, 1], - fill_value=0, - dtype="int32") + "need_think_end": + paddle.full(shape=[args.max_num_seqs, 1], + fill_value=0, + dtype="int32") }) self.share_inputs.update({ - "enable_thinking": paddle.full(shape=[1], - fill_value=True, - dtype="bool") + "enable_thinking": + paddle.full(shape=[1], fill_value=True, dtype="bool") }) self.share_inputs.update({ - "reasoning_index": paddle.full(shape=[ - args.max_num_seqs, 1], - fill_value=0, - dtype="int32") + "reasoning_index": + paddle.full(shape=[args.max_num_seqs, 1], + fill_value=0, + dtype="int32") }) - def init_rotary_position_embedding(self, max_model_len): + def init_rotary_position_embedding(self, max_model_len: int) -> None: + """ + Init rotary position embedding + """ pass def _init_kvcache(self): """ - 分享不拷贝数据 + Init kv cache """ cache_kvs = {} total_block_num = self.num_gpu_blocks @@ -352,7 +377,7 @@ class GPUVLModelRunner(VLModelRunnerBase): del value paddle.device.cuda.empty_cache() - def clear_parameters(self, pid): + def clear_parameters(self, pid: int) -> None: """ clear_parameters """ if "caches" in self.share_inputs: self.model.clear_parameters(pid) @@ -360,7 +385,7 @@ class GPUVLModelRunner(VLModelRunnerBase): paddle.device.cuda.empty_cache() self.model.log_memory_usage("clear all memory") - def update_parameters(self, pid): + def update_parameters(self, pid: int) -> None: """ update_parameters """ if "caches" not in self.share_inputs: self.model.update_parameters(pid) @@ -368,7 +393,7 @@ class GPUVLModelRunner(VLModelRunnerBase): self.model.log_memory_usage("update all memory") @paddle.no_grad() - def set_state_dict(self, args): + def set_state_dict(self, args: argparse.Namespace) -> None: """set_state_dict""" if not self.is_safetensors_model: rank_model_paths = [] @@ -401,7 +426,7 @@ class GPUVLModelRunner(VLModelRunnerBase): self.model.set_state_dict(state_dict) self.resampler_model.set_state_dict(resampler_state) else: - state_dict = load_checkpoint( + state_dict = load_tp_checkpoint( args.model_name_or_path, Ernie4_5_PretrainedModel, self.model_cfg, @@ -414,10 +439,14 @@ class GPUVLModelRunner(VLModelRunnerBase): self.model.set_state_dict(state_dict) @paddle.no_grad() - def vit_load(self, model_path, tensor_parallel_degree, - tensor_parallel_rank): + def vit_load( + self, + model_path: str, + tensor_parallel_degree: int, + tensor_parallel_rank: int, + ) -> None: """ - vit_load tp参数 + Load vit tp weight """ if tensor_parallel_degree == 1: rank_model_path 
= os.path.join(model_path, "model_state.pdparams") @@ -430,15 +459,18 @@ class GPUVLModelRunner(VLModelRunnerBase): raise ValueError(f"No such a file {rank_model_path}") @paddle.no_grad() - def inject_pp_vision_model(self, args, cfg): + def inject_pp_vision_model(self, args: argparse.Namespace, cfg: Ernie4_5_VLMoeConfig): """ - 注入vision model参数 + Inject pp vision model """ def set_vision_state_dict(model, - tensor_parallel_degree=8, - tensor_parallel_rank=0, - name=""): + tensor_parallel_degree: int=8, + tensor_parallel_rank: int=0, + name: str=""): + """ + Set vision model weight + """ model_state_dict = model.state_dict() compat_keys = [name + k for k in model_state_dict.keys()] model_files = set() @@ -543,7 +575,7 @@ class GPUVLModelRunner(VLModelRunnerBase): return vision_model, resampler_model @paddle.no_grad() - def extract_vision_features(self, inputs): + def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: """extract_vision_features""" assert inputs["images"] is not None grid_thw = inputs["grid_thw"] @@ -585,7 +617,7 @@ class GPUVLModelRunner(VLModelRunnerBase): return image_features @paddle.no_grad() - def prepare_rope3d(self, position_ids, **kwargs): + def prepare_rope3d(self, position_ids: paddle.Tensor, **kwargs) -> paddle.Tensor: """prepare_rope3d""" prefix_max_position_ids = paddle.max(position_ids) + 1 @@ -608,13 +640,13 @@ class GPUVLModelRunner(VLModelRunnerBase): def prefill_finished(self): """ - 判断是否已经完成了prefill操作 + Verify prefill operation completion """ prefill_statue = (self.share_inputs["seq_lens_this_time"] != 0) & ( self.share_inputs["seq_lens_this_time"] != 1) return not paddle.any(prefill_statue).numpy() - def dy_input_preprocess(self, tasks): + def dy_input_preprocess(self, tasks: list[any]) -> None: """ dynamic insertion """ @@ -662,7 +694,7 @@ class GPUVLModelRunner(VLModelRunnerBase): "image_features"] = self.extract_vision_features( inputs) else: - # 兼容没有图片和视频的情况 + # Compatible with the situation that lacks images and videos self.share_inputs["image_features"] = None if task.multimodal_inputs["position_ids"] is not None: position_ids = paddle.to_tensor( @@ -688,7 +720,7 @@ class GPUVLModelRunner(VLModelRunnerBase): "image_features"] = self.extract_vision_features( inputs) else: - # 兼容没有图片和视频的情况 + # Compatible with the situation that lacks images and videos self.share_inputs["image_features"] = None position_ids = inputs["position_ids"] @@ -702,10 +734,11 @@ class GPUVLModelRunner(VLModelRunnerBase): # force self.share_inputs["enable_thinking"][:] = kwargs["enable_thinking"] - self.share_inputs["need_think_end"][idx:idx + - 1, :] = 1 if kwargs["enable_thinking"] else 0 + self.share_inputs["need_think_end"][ + idx:idx + 1, :] = 1 if kwargs["enable_thinking"] else 0 - self.share_inputs["reasoning_index"][idx:idx + 1, :] = kwargs["reasoning_max_tokens"] + self.share_inputs["reasoning_index"][ + idx:idx + 1, :] = kwargs["reasoning_max_tokens"] self.share_inputs["rope_emb"][idx:idx + 1, :] = self.prepare_rope3d( @@ -737,7 +770,7 @@ class GPUVLModelRunner(VLModelRunnerBase): idx:idx + 1, :encoder_block_num] = np.array(task.block_tables, dtype="int32") - def pre_process(self): + def pre_process(self) -> None: """ pre_process """ @@ -794,7 +827,10 @@ class GPUVLModelRunner(VLModelRunnerBase): eos_token_ids=self.share_inputs["eos_token_id"], ) - def generate(self): + def generate(self) -> None: + """ + generate + """ self.pre_process() hiddden_states = self.model(self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], 
@@ -815,7 +851,10 @@ class GPUVLModelRunner(VLModelRunnerBase): paddle.distributed.broadcast(next_tokens, 0) self.post_process(next_tokens) - def post_process(self, next_tokens): + def post_process(self, next_tokens: paddle.Tensor) -> None: + """ + post_process + """ if self.share_inputs["enable_thinking"]: exists_think_end = next_tokens == self.model_cfg.think_end_id paddle.assign( @@ -823,37 +862,28 @@ class GPUVLModelRunner(VLModelRunnerBase): exists_think_end, self.share_inputs["need_think_end"] - 1, self.share_inputs["need_think_end"], - ), - self.share_inputs["need_think_end"] - ) + ), self.share_inputs["need_think_end"]) paddle.assign( paddle.where( self.share_inputs["need_think_end"].cast("bool"), self.share_inputs["reasoning_index"] - 1, self.share_inputs["reasoning_index"], - ), - self.share_inputs["reasoning_index"] - ) + ), self.share_inputs["reasoning_index"]) stop_wo_think = ( - ( - next_tokens == self.share_inputs["eos_token_id"] - ) | ( - self.share_inputs["reasoning_index"] == 0 - ) - ) & ( - self.share_inputs["need_think_end"] > 0 - ) - next_tokens = paddle.where(stop_wo_think, self.model_cfg.think_end_id, next_tokens) + (next_tokens == self.share_inputs["eos_token_id"]) | + (self.share_inputs["reasoning_index"] == 0)) & ( + self.share_inputs["need_think_end"] > 0) + next_tokens = paddle.where(stop_wo_think, + self.model_cfg.think_end_id, + next_tokens) paddle.assign( paddle.where( stop_wo_think, self.share_inputs["need_think_end"] - 1, self.share_inputs["need_think_end"], - ), - self.share_inputs["need_think_end"] - ) + ), self.share_inputs["need_think_end"]) paddle.assign( paddle.where( self.share_inputs["stop_flags"], @@ -899,14 +929,13 @@ class GPUVLModelRunner(VLModelRunnerBase): def _cal_theortical_kvcache(self): """ - 计算理论的kvcache大小 + Calculate the size of kvcache for computational theory """ num_layers = self.model_cfg.get("num_layers", None) or self.model_cfg.get( "num_hidden_layers", None) byte_of_cache = 2 - #TODO - # 支持c8 c4 + # support c8 c4 hidden_dim = self.model_cfg.head_dim * self.model_cfg.kv_num_head theoretical_kv_cache_memory = (2 * byte_of_cache * @@ -915,6 +944,9 @@ class GPUVLModelRunner(VLModelRunnerBase): return theoretical_kv_cache_memory def _update_share_input_block_num(self): + """ + Update share_inputs['block_tables'] and share_inputs['free_list'] + """ num_gpu_blocks = self.num_gpu_blocks del self.share_inputs["caches"] @@ -924,7 +956,7 @@ class GPUVLModelRunner(VLModelRunnerBase): self.share_inputs["block_tables"] = paddle.full( [self.args.max_num_seqs, num_gpu_blocks], -1, dtype="int32") - # 初始化free list + # Init free list free_list = list( range(num_gpu_blocks - 1, int(num_gpu_blocks * self.args.kv_cache_ratio) - 1, -1)) @@ -936,7 +968,7 @@ class GPUVLModelRunner(VLModelRunnerBase): paddle.full([1], self.free_list_len, dtype="int32"), }) - def dummy_input(self, num_total_tokens, number_of_tasks): + def dummy_input(self, num_total_tokens: int, number_of_tasks: int) -> None: """ fake input to profile """ @@ -974,7 +1006,7 @@ class GPUVLModelRunner(VLModelRunnerBase): self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(idx * block_num, \ (idx + 1) * block_num, 1) - def _preprocess_task(self, one): + def _preprocess_task(self, one: dict) -> None: """process batch""" input_ids = one["input_ids"][np.newaxis, :] @@ -1012,13 +1044,13 @@ class GPUVLModelRunner(VLModelRunnerBase): def build_stream_line_model( - model_path, - dtype, - block_size, - max_model_len, - tokenizer, + model_path: str, + dtype: str, + block_size: int, 
+ max_model_len: int, + tokenizer: ErnieBotTokenizer, quantization: str = "None", -): +) -> tuple[FDConfig, paddle.nn.layer]: """ build model """ @@ -1028,9 +1060,6 @@ def build_stream_line_model( from paddleformers.trl import llm_utils from paddleformers.utils.log import logger - from fastdeploy.config import (DeviceConfig, FDConfig, KVCacheConfig, - LoadConfig, ModelConfig, MoEConfig, - MoEPhase, ParallelConfig, SpeculativeConfig) from fastdeploy.model_executor.layers.quantization import \ get_quantization_config from fastdeploy.model_executor.models.model_base import ModelRegistry diff --git a/fastdeploy/worker/vl_model_runner_base.py b/fastdeploy/worker/vl_model_runner_base.py index 29894890f..d6d8cc4f8 100644 --- a/fastdeploy/worker/vl_model_runner_base.py +++ b/fastdeploy/worker/vl_model_runner_base.py @@ -15,10 +15,12 @@ """ from abc import ABC, abstractmethod +import argparse import paddle import paddle.distributed as dist import paddle.distributed.fleet as fleet +from fastdeploy.config import ModelConfig from fastdeploy.utils import get_logger @@ -27,20 +29,20 @@ logger = get_logger("worker", "worker.log") class VLModelRunnerBase(ABC): """ - Initializes the model and sets up necessary parameters. - - Args: - config (Config): The configuration object for the model. - args (Namespace): The arguments passed to the script. - - Returns: - None. - - Raises: - None. + Engine -> (WIP)Executor -> Worker -> VLModelRunnerBase -> Model + VLModelRunnerBase interface abstracts the model execution logic that + contain input preparation, token generation, and tokenprocessing. """ - def __init__(self, config, args): + def __init__( + self, + config: ModelConfig, + args: argparse.Namespace, + ) -> None: + """ + VLModelRunnerBase init + """ + self.share_inputs = {} self.model_cfg = config self.args = args @@ -66,7 +68,7 @@ class VLModelRunnerBase(ABC): f"current_allocated: {curr_alloc:.2f}GB\n" f"current_reserved: {curr_reserved:.2f}GB") - def init_dist_env(self, seed=20): + def init_dist_env(self, seed=20) -> None: """ init distributed env """ @@ -85,7 +87,7 @@ class VLModelRunnerBase(ABC): fleet.init(is_collective=True, strategy=strategy) self.rank = fleet.worker_index() - def _load_model_init_val(self): + def _load_model_init_val(self) -> None: """ initialize model config from config file """ @@ -105,18 +107,10 @@ class VLModelRunnerBase(ABC): self.min_length = _get_attr("min_length", 1) self.max_length = self.args.max_model_len - def _init_share_inputs(self, max_num_seqs): + def _init_share_inputs(self, max_num_seqs: int) -> None: """ - 初始化共享的输入,包括预测和训练。 - 将所有需要的张量都初始化为零或者特定值。 - - Args: - max_num_seqs (int): 最大批次大小,用于初始化张量。 - - Returns: - None. 
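The post_process hunk in vl_gpu_model_runner.py above implements a per-request "thinking budget": need_think_end marks requests still inside a reasoning segment, reasoning_index counts down by one token per step, and when the budget hits zero (or EOS arrives) the think-end token is force-emitted. A simplified single-request re-implementation of that bookkeeping (plain Python rather than the batched paddle.where version in the patch; the think_end_id/eos_token_id values are made up for illustration):

def post_process_step(next_token: int,
                      need_think_end: int,
                      reasoning_index: int,
                      think_end_id: int = 2,    # assumed value for illustration
                      eos_token_id: int = 1):   # assumed value for illustration
    """Return (token_to_emit, need_think_end, reasoning_index) after one decode step."""
    # The model emitted the think-end token on its own: reasoning segment is done.
    if next_token == think_end_id:
        need_think_end -= 1
    # While still thinking, each generated token consumes one unit of budget.
    if need_think_end > 0:
        reasoning_index -= 1
    # Budget exhausted (or EOS) while still thinking: force the think-end token
    # so generation leaves the reasoning segment cleanly.
    if need_think_end > 0 and (reasoning_index == 0 or next_token == eos_token_id):
        next_token = think_end_id
        need_think_end -= 1
    return next_token, need_think_end, reasoning_index


if __name__ == "__main__":
    # A request with enable_thinking=True and reasoning_max_tokens=3.
    state = [1, 3]  # [need_think_end, reasoning_index]
    for tok in [17, 42, 99, 7]:
        tok, *state = post_process_step(tok, *state)
        print(tok, state)   # third step force-emits the think-end token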
+ initialize shared inputs """ - # 统一使用paddle.full创建张量 self._load_model_init_val() int64_config = {"dtype": "int64"} @@ -124,7 +118,6 @@ class VLModelRunnerBase(ABC): float32_config = {"dtype": "float32"} bool_config = {"dtype": "bool"} - # 批量初始化张量 self.share_inputs.update({ "pre_ids": paddle.full([max_num_seqs, self.max_length], -1, **int64_config), @@ -146,7 +139,6 @@ class VLModelRunnerBase(ABC): "presence_score": paddle.full([max_num_seqs, 1], self.presence_score, **float32_config), - # TODO 名称统一 "min_dec_len": paddle.full([max_num_seqs, 1], self.min_length, **int64_config), "max_dec_len": @@ -207,14 +199,12 @@ class VLModelRunnerBase(ABC): paddle.full([max_num_seqs, 1], -1, **int32_config), }) - # 计算block tables相关参数 pre_max_block_num = ( self.args.max_model_len + self.args.block_size - 1) // self.args.block_size + self.args.enc_dec_block_num self.share_inputs["block_tables"] = paddle.full( [max_num_seqs, pre_max_block_num], -1, **int32_config) - # 初始化free list free_list = list( range( self.args.total_block_num - 1, @@ -228,7 +218,6 @@ class VLModelRunnerBase(ABC): paddle.full([1], self.free_list_len, **int32_config), }) - # 初始化stop seqs self.share_inputs.update({ "stop_seqs_len": paddle.full([self.model_cfg.max_stop_seqs_num], 0, **int32_config), @@ -239,9 +228,9 @@ class VLModelRunnerBase(ABC): ], -1, **int64_config), }) - def update_chunked_prefill(self, tasks): + def update_chunked_prefill(self, tasks: list[any]) -> None: """ - 更新chunked prefill相关参数 + update chunked prefill """ if not self.args.enable_chunked_prefill: return @@ -251,58 +240,38 @@ class VLModelRunnerBase(ABC): def prefill_finished(self): """ - 判断是否已经完成了prefill操作 + Verify prefill operation completion """ return True @abstractmethod - def init_rotary_position_embedding(self, max_model_len): + def init_rotary_position_embedding(self, max_model_len: int) -> None: """ - 初始化旋转位置编码,需要重写该方法。 - 参数max_model_len(int):序列的最大长度。 - 返回值(None):无返回值,需要在方法内完成初始化操作。 + Init rotary position embedding """ raise NotImplementedError @abstractmethod - def _load_model(self, model_dir, dynamic_load_weight): + def _load_model( + self, + model_name: str, + dynamic_load_weight: int = 0, + ) -> None: """ - 加载模型,包括模型参数和优化器等。 - 需要子类实现该方法。 - - Args: - model_dir (str): 模型保存的目录路径。 - - Raises: - NotImplementedError: 当前方法未被实现。 - - Returns: - None. + Load the model from the given model name. 
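_init_share_inputs above sizes the per-request block table from max_model_len and block_size, and seeds the free-block list from total_block_num and kv_cache_ratio. A small standalone sketch of those two computations, using the CLI defaults that appear elsewhere in this patch (3072 / 64 / 1 / 2000 / 0.7):

def init_block_bookkeeping(max_model_len: int,
                           block_size: int,
                           enc_dec_block_num: int,
                           total_block_num: int,
                           kv_cache_ratio: float):
    # Worst-case KV blocks one request can touch: ceil-divide the maximum
    # sequence length by the block size, plus the encoder/decoder blocks.
    pre_max_block_num = (max_model_len + block_size - 1) // block_size + enc_dec_block_num

    # Free list as built in the patch: the blocks above the kv_cache_ratio
    # watermark, listed from the highest id downwards.
    free_list = list(range(total_block_num - 1,
                           int(total_block_num * kv_cache_ratio) - 1,
                           -1))
    return pre_max_block_num, free_list


if __name__ == "__main__":
    per_req_blocks, free = init_block_bookkeeping(
        max_model_len=3072, block_size=64, enc_dec_block_num=1,
        total_block_num=2000, kv_cache_ratio=0.7)
    print(per_req_blocks)                  # 49 = ceil(3072 / 64) + 1
    print(len(free), free[0], free[-1])    # 600 free blocks, ids 1999 .. 1400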
""" raise NotImplementedError @abstractmethod def _init_kvcache(self): """ - 初始化kv缓存,用于快速查找数据块。 - 该方法需要被子类实现。 - - Args: - max_block_num (int): 最大的数据块数量。 - - Raises: - NotImplementedError: 当该方法未被子类实现时会引发此异常。 + Init kv cache """ raise NotImplementedError @abstractmethod - def dy_input_preprocess(self): + def dy_input_preprocess(self, tasks: list[any]) -> None: """ - 预处理输入数据,用于计算dy。 - 该函数需要在每次forward之前调用,并且只能调用一次。 - 默认实现抛出NotImplementedError。子类可以根据具体的模型实现此功能。 - - Raises: - NotImplementedError: 如果没有实现该方法。 + dynamic insertion """ raise NotImplementedError diff --git a/fastdeploy/worker/vl_worker_process.py b/fastdeploy/worker/vl_worker_process.py index e555c4222..98b47d270 100644 --- a/fastdeploy/worker/vl_worker_process.py +++ b/fastdeploy/worker/vl_worker_process.py @@ -26,6 +26,7 @@ import paddle.distributed.fleet as fleet from fastdeploy.engine.config import ModelConfig from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal from fastdeploy.utils import get_logger, none_or_str +from fastdeploy.worker.worker_process import initialize_fd_config, parse_args logger = get_logger("worker", "worker.log") @@ -35,7 +36,14 @@ class PrefillTracker: Record the prefill time of the request """ - def __init__(self, engine_pid): + def __init__( + self, + engine_pid: int, + ) -> None: + """ + Initialize the PrefillTracker. + """ + super().__init__() self.start_times = defaultdict(float) prefill_time_data = np.zeros([100], dtype=np.float32) self.prefill_time_signal = IPCSignal(name="prefill_time_signal", @@ -46,7 +54,7 @@ class PrefillTracker: self.current_index = 0 self.executor = ThreadPoolExecutor(max_workers=1) - def start_prefill(self, task_idx): + def start_prefill(self, task_idx: int): """ Record the start time of the prefill process for a given task index. @@ -55,7 +63,7 @@ class PrefillTracker: """ self.start_times[task_idx] = time.time() - def end_prefill(self, task_idx): + def end_prefill(self, task_idx: int): """ Record the end time of the prefill process for a given task index and asynchronously submit the duration for metric recording. @@ -69,7 +77,7 @@ class PrefillTracker: self.executor.submit(self._record_metrics, duration) del self.start_times[task_idx] - def _record_metrics(self, duration): + def _record_metrics(self, duration: float): """ Internal method to record the prefill duration into the signal buffer. Logs the duration and updates a circular buffer of timing metrics. @@ -89,19 +97,19 @@ class PrefillTracker: class Worker: + """ + Engine -> (WIP)Executor -> Worker -> ModelRunner -> Model + Worker interface that allows inference framwork to cleanly separate implementations for different harware. + """ - def __init__(self, args): + def __init__( + self, + args, + ) -> None: """ - Args: - args (ArgumentParser): 命令行参数,包含模型名称、端口号等信息。 - - Returns: - None, 无返回值,初始化完成后会将相关参数和对象保存到类属性中。 - - Raises: - None, 没有异常抛出。 + Initialize the Worker. 
""" - + super().__init__() self.args = args self.MAX_INFER_SEED = 9223372036854775806 paddle.set_default_dtype(args.dtype) @@ -123,7 +131,7 @@ class Worker: rank=self.rank) self.prefill_tracker = PrefillTracker(args.engine_pid) - # TODO 多机 + # Only applicable for standalone (single-machine) inference address = ('0.0.0.0', self.args.engine_worker_queue_port) self.engine_worker_queue = EngineWorkerQueue( address=address, @@ -154,7 +162,10 @@ class Worker: self.rank = fleet.worker_index() def init_health(self): - # worker_ready_signal 用于engine感知各worker进程是否Ready + """ + init health signals + """ + # To perceive whether each worker process is ready worker_ready_signal_data = np.zeros(shape=[self.nranks], dtype=np.int32) self.worker_ready_signal = IPCSignal(name="worker_ready_signal", @@ -164,7 +175,7 @@ class Worker: create=False) self.worker_ready_signal.value[self.rank] = 1 - # worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间 + # To monitor the liveness of worker processes and record each step's timestamp worker_healthy_live_recorded_time_array = np.zeros(shape=[self.nranks], dtype=np.int32) self.worker_healthy_live_signal = IPCSignal( @@ -175,7 +186,7 @@ class Worker: create=False) self.worker_healthy_live_signal.value[self.rank] = int(time.time()) - # exist_task_signal 用于各worker进程感知是否有新Task需要处理 + # To perceive whether there is a new task to be processed exist_task_signal_data = np.zeros([1], dtype=np.int32) self.exist_task_signal = IPCSignal(name="exist_task_signal", array=exist_task_signal_data, @@ -183,7 +194,7 @@ class Worker: suffix=self.args.engine_pid, create=False) - # exist_swapped_task_signal 用于engine感知worker中是否存在swapped task + # To detect whether there are swapped tasks in the worker exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32) self.exist_swapped_task_signal = IPCSignal( name="exist_swapped_task_signal", @@ -192,7 +203,6 @@ class Worker: suffix=self.args.engine_pid, create=False) - # model_weights_status 用于engine感知各worker中模型权重状态 model_weights_status = np.zeros([1], dtype=np.int32) self.model_weights_status_signal = IPCSignal( name="model_weights_status", @@ -309,17 +319,7 @@ class Worker: def run(self): """ - 运行函数,不断地从队列中获取任务并进行推理。 - 当队列为空或者所有节点都处于等待状态时,将会休眠一段时间再次尝试获取任务。 - - Args: - None. - - Returns: - None. - - Raises: - None. + run function, continuously get tasks and do inference. 
""" infer_seed_increment = paddle.full(shape=[self.args.max_num_seqs, 1], fill_value=4, @@ -526,153 +526,6 @@ class Worker: break -def parse_args(): - """ - parse args from command line - """ - parser = argparse.ArgumentParser("FastDeploy LLM Inference") - parser.add_argument("-m", - "--model_name_or_path", - type=str, - default="./output", - help="model dir") - parser.add_argument("-mbs", - "--max_num_seqs", - type=int, - default=34, - help="max batch size") - parser.add_argument("--total_block_num", type=int, default=2000) - parser.add_argument("--block_size", type=int, default=64) - parser.add_argument("--engine_worker_queue_port", type=int, default=9923) - parser.add_argument("--max_model_len", - type=int, - default=3072, - help="max model len") - parser.add_argument("--device_ids", - type=str, - default="0", - help="cuda visible devices") - parser.add_argument("--dtype", - type=str, - default="bfloat16", - help="input dtype") - parser.add_argument("--enc_dec_block_num", - type=int, - default=1, - help="encoder's decoder num") - parser.add_argument("--kv_cache_ratio", - type=float, - default=0.7, - help="kv cache ratio for input") - parser.add_argument("--first_token_id", - type=int, - default=1, - help="first token id") - parser.add_argument("--gpu_memory_utilization", - type=float, - default=0.9, - help="gpu memory utilization") - parser.add_argument("--engine_pid", - type=int, - default=None, - help="Process ID of engine") - parser.add_argument("--do_profile", - action='store_true', - help="do profile or not") - parser.add_argument("--dynamic_load_weight", - action='store_true', - help="dynamic load weight or not") - parser.add_argument("--pad_token_id", - type=int, - default=-1, - help="pad token id") - parser.add_argument("--eos_tokens_lens", - type=int, - default=2, - help="eos token lens") - parser.add_argument("--enable_chunked_prefill", - action='store_true', - help="enable chunked prefill") - parser.add_argument( - "--speculative_method", - default=None, - type=none_or_str, - choices=[None, "ngram", "mtp"], - ) - parser.add_argument( - "--speculative_max_draft_token_num", - default=1, - type=int, - ) - parser.add_argument( - "--speculative_model_name_or_path", - default="", - type=str, - ) - parser.add_argument( - "--speculative_model_quantization", - default="", - type=str, - ) - parser.add_argument( - "--attention_backend", - default="APPEND_ATTN", - type=str, - choices=[ - "APPEND_ATTN", - ], - ) - parser.add_argument("--max_num_batched_tokens", - type=int, - default=2048, - help="max num batched tokens") - parser.add_argument("--enable_prefix_caching", - action='store_true', - help="enable prefix cache") - parser.add_argument("--splitwise_role", - type=str, - default="mixed", - help="splitwise role") - parser.add_argument("--ori_vocab_size", type=int, default=None) - parser.add_argument("--tensor_parallel_size", - type=int, - default=1, - help="tensor parallel size") - parser.add_argument("--expert_parallel_size", - type=int, - default=1, - help="expert parallel size") - parser.add_argument("--quantization", - type=str, - default="", - help="Quantization name for the model, currentlly support " \ - "'wint4', 'wint8'," \ - "default is None. The priority of this configuration "\ - "is lower than that of the config file. 
" \ - "More complex quantization methods need to be configured via the config file.") - parser.add_argument("--enable_static_graph_inference", - action='store_true', - help="Whether to use static mode; if enabled, " \ - "'paddle.to_static' will be used to convert dynamic to static.") - parser.add_argument("--use_cudagraph", - action='store_true', - help="Flags to enable cuda graph.") - parser.add_argument("--max_capture_batch_size", - type=int, - default=64, - help="Maximum of Batch Size for Warm Up.") - parser.add_argument("--guided_decoding_backend", - type=str, - default="off", - help="guided decoding backend") - parser.add_argument("--disable_any_whitespace", - action='store_false', - help="Disable any whitespace for guided decoding.") - - args = parser.parse_args() - return args - - def main(): """ start worker diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index f7a44d6e3..2b2642c2a 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -14,7 +14,6 @@ # limitations under the License. """ import argparse -import json import time from typing import List @@ -23,6 +22,7 @@ import paddle import paddle.distributed as dist import paddle.distributed.fleet as fleet +from fastdeploy import envs from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig, GraphOptimizationConfig, LoadConfig, ModelConfig, MoEConfig, MoEPhase, @@ -61,14 +61,21 @@ class PaddleDisWorkerProc(): def __init__( self, fd_config: FDConfig, - ): + ) -> None: + """ + Initialize a distributed worker and task queue for single-node multi-GPU setup. + Args: + fd_config (FDConfig): Arguments related to inference, containing + attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim, + num_attention_heads, and ffn_hidden_size. + """ self.fd_config = fd_config self.parallel_config = fd_config.parallel_config # Initialize distributed enviroment - (self.rank, self.local_rank) = self.init_distributed_enviroment() + (self.ranks, self.local_rank) = self.init_distributed_enviroment() - assert self.parallel_config.tensor_parallel_degree * self.parallel_config.expert_parallel_degree == self.rank + assert self.parallel_config.tensor_parallel_degree * self.parallel_config.expert_parallel_degree == self.ranks self.fd_config.parallel_config.tensor_parallel_rank = \ self.local_rank % self.parallel_config.tensor_parallel_degree @@ -81,8 +88,6 @@ class PaddleDisWorkerProc(): self.fd_config.moe_config.num_experts_start_offset = \ self.fd_config.parallel_config.expert_parallel_rank * self.fd_config.moe_config.num_experts_per_rank - self.fd_config.parallel_config.column_cut = False - # For auto TP split self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_degree self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank @@ -95,7 +100,7 @@ class PaddleDisWorkerProc(): # TODO(gongshaotian): Use worker factory to get worker self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, - rank=self.rank) + rank=self.ranks) # Initialize task queue task_address = ('0.0.0.0', @@ -109,7 +114,7 @@ class PaddleDisWorkerProc(): local_data_parallel_id=self.fd_config.parallel_config. expert_parallel_rank) - def init_health_status(self): + def init_health_status(self) -> None: """ Initialize the health status of the worker. 
Worker Status: @@ -134,7 +139,7 @@ class PaddleDisWorkerProc(): self.worker_ready_signal.value[self.local_rank % 8] = 1 # init worker_healthy_live_signal - workers_alive = np.zeros(shape=[self.rank], dtype=np.int32) + workers_alive = np.zeros(shape=[self.ranks], dtype=np.int32) self.worker_healthy_live_signal = IPCSignal( name="worker_healthy_live_signal", array=workers_alive, @@ -183,16 +188,7 @@ class PaddleDisWorkerProc(): suffix=self.parallel_config.engine_pid, create=False) - # init model_weights_status - workers_model_weights = np.zeros(shape=[1], dtype=np.int32) - self.model_weights_status = IPCSignal( - name="model_weights_status", - array=workers_model_weights, - dtype=np.int32, - suffix=self.parallel_config.engine_pid, - create=False) - - def event_loop_ep(self): + def event_loop_ep(self) -> None: """ Tmp loop function for ep utill DP is supported """ @@ -217,7 +213,7 @@ class PaddleDisWorkerProc(): # These generated tokens can be obtained through get_output op. self.worker.execute_model() - def event_loop_normal(self): + def event_loop_normal(self) -> None: """ Main event loop for Paddle Distrubuted Workers. TODO(gongshaotian): support remote calling of functions that control worker. """ @@ -225,6 +221,12 @@ class PaddleDisWorkerProc(): self.nnode = 1 req_ids = [] while True: + if self.local_rank == 0: + if self.model_weights_status.value[0] != 0: + self.exist_task_signal.value[0] = 2 + else: + self.exist_task_signal.value[0] = 0 + if self.parallel_config.tensor_parallel_degree > 1: # Synchronize before updating weights paddle.distributed.barrier() @@ -234,7 +236,7 @@ class PaddleDisWorkerProc(): time.time()) # The first worker detects whether there are tasks in the task queue - mp_num_per_node = self.rank / self.nnode + mp_num_per_node = self.ranks / self.nnode if self.local_rank % mp_num_per_node == 0: if self.task_queue.num_tasks() > 0: if self.nnode > 1: @@ -249,6 +251,14 @@ class PaddleDisWorkerProc(): # TODO(@wufeisheng): Split TP group and EP group paddle.distributed.barrier() + if self.fd_config.load_config.dynamic_load_weight: + if self.exist_task_signal.value[0] == 2: + from fastdeploy.rl.dynamic_weight_manager import \ + DynamicWeightManager + DynamicWeightManager.check_model_weights_status( + self.model_weights_status, self.worker.model_runner, + self.parallel_config.engine_pid) + if self.exist_task_signal.value[ self.fd_config.parallel_config.expert_parallel_rank] == 1 or \ self.task_queue.read_finish_flag.get() == 1: @@ -275,7 +285,7 @@ class PaddleDisWorkerProc(): self.worker.preprocess_new_task(req_dicts) if not self.worker.model_runner.not_need_stop(): - if self.rank > 1: + if self.ranks > 1: paddle.distributed.barrier() time.sleep(0.001) @@ -288,15 +298,15 @@ class PaddleDisWorkerProc(): self.exist_prefill_task_signal.value[ 0] = self.worker.prefill_finished() - def init_distributed_enviroment(self, seed=20) -> List[int]: + def init_distributed_enviroment(self, seed: int = 20) -> List[int]: """ Initialize Paddle Fleet and get rank of worker """ # Global rank - self.rank = dist.get_world_size() + self.ranks = dist.get_world_size() dist_strategy = fleet.DistributedStrategy() dist_strategy.hybrid_configs = { "dp_degree": 1, - "mp_degree": self.rank, + "mp_degree": self.ranks, "pp_degree": 1, "sharding_degree": 1, } @@ -308,10 +318,19 @@ class PaddleDisWorkerProc(): # Local rank self.local_rank = fleet.worker_index() - return self.rank, self.local_rank + return self.ranks, self.local_rank - def determine_num_available_blocks(self): - """ + def 
determine_num_available_blocks(self) -> None: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. """ if self.fd_config.parallel_config.do_profile: # 1. Get available memory(bytes) @@ -343,7 +362,8 @@ class PaddleDisWorkerProc(): ) # 3. Send IPCSignal - get_profile_block_num = np.zeros(shape=[self.rank], dtype=np.int32) + get_profile_block_num = np.zeros(shape=[self.ranks], + dtype=np.int32) self.get_profile_block_num_signal = IPCSignal( name="get_profile_block_num", array=get_profile_block_num, @@ -366,12 +386,12 @@ class PaddleDisWorkerProc(): # 4. Updata share inputs self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_global) - def init_device(self): - """ """ + def init_device(self) -> None: + """ Initialize device and Construct model runner """ self.worker.init_device() - def load_model(self): - """ """ + def load_model(self) -> None: + """ Load weights and create model """ self.worker.load_model() @@ -428,9 +448,6 @@ def parse_args(): parser.add_argument("--do_profile", action='store_true', help="do profile or not") - parser.add_argument("--dynamic_load_weight", - action='store_true', - help="dynamic load weight or not") parser.add_argument("--pad_token_id", type=int, default=-1, @@ -467,14 +484,6 @@ def parse_args(): default="WINT8", type=str, ) - parser.add_argument( - "--attention_backend", - default="APPEND_ATTN", - type=str, - choices=[ - "APPEND_ATTN", - ], - ) parser.add_argument("--max_num_batched_tokens", type=int, default=2048, @@ -527,11 +536,26 @@ def parse_args(): parser.add_argument("--disable_any_whitespace", action='store_false', help="Disable any whitespace for guided decoding.") + parser.add_argument("--dynamic_load_weight", + action='store_true', + help="Enable dynamic weight loading strategy") + parser.add_argument( + "--load_strategy", + type=str, + choices=['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta', 'normal'], + default='meta', + help="Weight loading method when dynamic loading is enabled: " + "'ipc': real-time IPC streaming with automatic resharding, " + "'ipc_no_reshard': IPC streaming without weight processing, " + "'ipc_snapshot': load from disk snapshot of IPC weights, " + "'meta': provide RL traing worker, no_weights_load" + "'normal':normal load weight") + args = parser.parse_args() return args -def initialize_fd_config(args) -> FDConfig: +def initialize_fd_config(args: argparse.Namespace) -> FDConfig: """Initialize FDConfig TODO(gongshaotian): Unified all configs to FDConfig """ @@ -554,7 +578,7 @@ def initialize_fd_config(args) -> FDConfig: # model_config = ModelConfig() decoding_config = DecodingConfig() - decoding_config = MoEConfig() + speculative_config = SpeculativeConfig() parallel_config = ParallelConfig() load_config = LoadConfig() @@ -592,7 +616,6 @@ def initialize_fd_config(args) -> FDConfig: parallel_config.pad_token_id = args.pad_token_id parallel_config.eos_tokens_lens = args.eos_tokens_lens parallel_config.enable_chunked_prefill = args.enable_chunked_prefill - parallel_config.attention_backend = args.attention_backend parallel_config.max_num_batched_tokens = args.max_num_batched_tokens parallel_config.enable_prefix_caching = args.enable_prefix_caching @@ 
-600,6 +623,7 @@ def initialize_fd_config(args) -> FDConfig: parallel_config.tensor_parallel_degree = args.tensor_parallel_size parallel_config.expert_parallel_degree = args.expert_parallel_size parallel_config.splitwise_role = args.splitwise_role + load_config.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1 parallel_config.guided_decoding_backend = args.guided_decoding_backend parallel_config.disable_any_whitespace = args.disable_any_whitespace @@ -659,19 +683,20 @@ def initialize_fd_config(args) -> FDConfig: moe_config.num_max_dispatch_tokens_per_rank = config.get( "num_max_dispatch_tokens_per_rank", 256) + moe_config.moe_use_aux_free = config.get("moe_use_aux_free", False) model_config.ori_vocab_size = config.get("vocab_size", -1) if "Ernie4_5_ForCausalLM" in config.get("architectures"): model_config.ori_vocab_size = args.ori_vocab_size - quantization_config = config.get("quantization_config", None) + if "DeepseekV3ForCausalLM" in config.get("architectures"): + from paddleformers.transformers import AutoConfig + model_config.deepseekv3 = AutoConfig.from_pretrained( + args.model_name_or_path) - # Note(@wufeisheng): The `is_quantized` flag should be explicitly set to `true` - # when the weights are actually quantized offline. For backward compatibility - # with preview logic: - # - If `quantization_config` is provided but `is_quantized` is not explicitly set, - # the value of `is_quantized` will be determined by whether `kv_cache_quant_type` - # has been configured. + #TODO(@yuanrisheng): kv_cache quant config can only be + # stored in model config file, which should be unified + quantization_config = config.get("quantization_config", None) if not model_config.is_quantized: if quantization_config is not None: if "kv_cache_quant_type" not in quantization_config: @@ -689,9 +714,14 @@ def initialize_fd_config(args) -> FDConfig: elif args.quantization != "None": quantization_config = {} quant_config_name = args.quantization - if use_moe and quant_config_name == "wint4": + quantization_config["quantization"] = quant_config_name + # use some trick code for ernie model and will unify it in future. + is_ernie = "Ernie4_5_ForCausalLM" in config.get("architectures") or \ + "Ernie4_5_MoeForCausalLM" in config.get("architectures") + if use_moe and quant_config_name == "wint4" and is_ernie: quantization_config["dense_quant_type"] = "wint8" quantization_config["moe_quant_type"] = "wint4" + quantization_config["quantization"] = "mix_quant" quant_config_name = "mix_quant" else: quant_config_name = None @@ -706,20 +736,26 @@ def initialize_fd_config(args) -> FDConfig: if quant_config is not None: if model_config.is_quantized: logger.info( - "=====The currently loaded model is an offline quantized model=====" + "Model Status: Offline Quantized (pre-quantized weights loaded)" ) else: - logger.info("=====The currently loaded model is the original model\ - The model will be quantized online=====") - logger.info(f"{json.dumps(quantization_config, indent=2)}") + logger.info( + "Model Status: Original (will apply online quantization)") + + logger.info(f"Quantization Method: {args.quantization or 'None'}") else: logger.info( "No quantization config found and use original weight and act dtype." 
) - logger.info("============================================") model_config.architectures = config.get("architectures") + logger.info("===========load_config==============") + load_config.dynamic_load_weight = args.dynamic_load_weight + load_config.load_strategy = args.load_strategy + logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}") + logger.info(f"- Load strategy: {load_config.load_strategy}") + fd_config = FDConfig(model_config=model_config, parallel_config=parallel_config, speculative_config=speculative_config, @@ -733,7 +769,7 @@ def initialize_fd_config(args) -> FDConfig: return fd_config -def run_worker_proc(): +def run_worker_proc() -> None: """ start worker process """ diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index bf2253cb8..8df9357b5 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -583,15 +583,14 @@ class XPUModelRunner(ModelRunnerBase): head_dim = self.model_config.head_dim # Get the attention backend - attn_cls = get_attention_backend( - self.parallel_config.attention_backend) + attn_cls = get_attention_backend() attn_backend = attn_cls(self.fd_config, kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim) if attn_backend is None: raise NotImplementedError( - f"{ self.parallel_config.attention_backend} attention backend is not support by XPUModelRunner" + "Attention backend which you chose is not support by GPUModelRunner" ) self.attn_backends.append(attn_backend) diff --git a/requirements.txt b/requirements.txt index ef3573857..72f79add4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ moviepy triton==3.3 use-triton-in-paddle crcmod +fastsafetensors==0.1.14