diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
index 31c7bc061..2b3110f9d 100644
--- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
@@ -56,15 +56,14 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
-  int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t all_warp_num = gridDim.x * blockDim.x;
+  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
+  int64_t all_warp_num = gridDim.x * blockDim.y;
   int64_t all_head_dim = elem_cnt / head_size;
 
   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
-  // const int64_t offset = 2 * hidden_size;
   const int half_head_size = head_size / 2;
   for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
-    int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize;
+    int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
     const int ori_bi = linear_index / hidden_size;
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size;  // q + k + v
@@ -122,13 +121,13 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
       float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
       LoadT q_norm_vec, k_norm_vec;
       if (hi < num_heads) {  // q
-        Load(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
+        Load(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
         for (int i = 0; i < VecSize; i++) {
           out_vec[i] = static_cast(static_cast(out_vec[i]) * row_inv_var * static_cast(q_norm_vec[i]));
         }
       } else {  // k
-        Load(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
+        Load(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
         for (int i = 0; i < VecSize; i++) {
           out_vec[i] = static_cast(static_cast(out_vec[i]) * row_inv_var * static_cast(k_norm_vec[i]));
         }
diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
index 77cdfa300..8561460d1 100644
--- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
+++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu
@@ -45,7 +45,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
   const uint32_t elem_nums = use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
                                             : bsz * (num_heads + 2 * kv_num_heads) * dim_head;
-  assert(dim_head == 128 && "dim_head must be 128");
   constexpr int HEAD_DIM = 128;
   constexpr int PackSize = HEAD_DIM / kWarpSize;
@@ -53,7 +52,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
   const int blocksize = 128;
   int grid_size = 1;
   GetNumBlocks<128>(pack_num, &grid_size);
-  dim3 block_dim(blocksize / kWarpSize, kWarpSize, 1);
+  dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
   append_decode_cache_T_rope_qk_norm_kernel
       <<>>(reinterpret_cast(qkv),
              key_cache,
diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh
index 5215b933a..74169349e 100644
--- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh
@@ -432,13 +432,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
   LoadT src_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
-  int64_t global_warp_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int64_t all_warp_num = gridDim.x * blockDim.x;
+  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
+  int64_t all_warp_num = gridDim.x * blockDim.y;
   const int half_lastdim = last_dim / 2;
   const int offset = (q_num_head + kv_num_head) * last_dim;
   const int all_head_num = elem_cnt / last_dim;
   for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
-    int64_t linear_index = gloabl_hi * last_dim + threadIdx.y * VecSize;
+    int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
     const int token_idx = linear_index / offset;
     const int ori_bi = batch_id_per_token[token_idx];
     if (seq_lens[ori_bi] == 0) continue;
@@ -478,13 +478,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
       float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
       LoadT q_norm_vec, k_norm_vec;
       if (hi < q_num_head) {
-        Load(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
+        Load(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
         for (int i = 0; i < VecSize; i++) {
           src_vec[i] = static_cast(static_cast(src_vec[i]) * row_inv_var * static_cast(q_norm_vec[i]));
         }
       } else {
-        Load(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
+        Load(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
         for (int i = 0; i < VecSize; i++) {
           src_vec[i] = static_cast(static_cast(src_vec[i]) * row_inv_var * static_cast(k_norm_vec[i]));
         }
@@ -1690,13 +1690,13 @@ void gqa_rotary_qk_norm_variable(
   const int blocksize = 128;
   int grid_size = 1;
   GetNumBlocks<128>(pack_num, &grid_size);
-  dim3 Blocks(grid_size/kWarpSize, kWarpSize, 1);
+  dim3 Block_Size(kWarpSize, blocksize/kWarpSize, 1);
   const float *cos_emb = rotary_emb;
   const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
   GQAVariableLengthRotaryQKNormKernel
-      <<>>(
+      <<>>(
           reinterpret_cast(qkv_input),
           cos_emb,
           sin_emb,
diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh
index 9efbab433..125cb246c 100644
--- a/custom_ops/gpu_ops/append_attn/utils.cuh
+++ b/custom_ops/gpu_ops/append_attn/utils.cuh
@@ -430,6 +430,9 @@ __forceinline__ __host__ __device__ void vec_cast(
   } else if (group_size == 12) {          \
     constexpr size_t GROUP_SIZE = 12;     \
     __VA_ARGS__                           \
+  } else if (group_size == 14) {          \
+    constexpr size_t GROUP_SIZE = 14;     \
+    __VA_ARGS__                           \
   } else if (group_size == 16) {          \
     constexpr size_t GROUP_SIZE = 16;     \
     __VA_ARGS__                           \
diff --git a/custom_ops/gpu_ops/noaux_tc.cu b/custom_ops/gpu_ops/noaux_tc.cu
index 7b6d432c8..19a9e380f 100644
--- a/custom_ops/gpu_ops/noaux_tc.cu
+++ b/custom_ops/gpu_ops/noaux_tc.cu
@@ -28,19 +28,20 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                     int topk,
                                     float routed_scaling_factor) {
   auto input_shape = scores_with_bias.shape();
+  PD_CHECK(input_shape.size() == 2);
   int64_t num_tokens = input_shape[0];
   int64_t num_experts = input_shape[1];
   auto input_type = scores_with_bias.dtype();
   auto place = scores_with_bias.place();
   auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
   auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
-  auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT32, place);
+  auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
   auto stream = scores_with_bias.stream();
-  invokeNoAuxTc(reinterpret_cast(scores.data()),
+  invokeNoAuxTc(reinterpret_cast(scores.data()),
                 reinterpret_cast(group_scores.data()),
                 reinterpret_cast(topk_values.data()),
-                reinterpret_cast(topk_indices.data()),
+                reinterpret_cast(topk_indices.data()),
                 reinterpret_cast(scores_with_bias.data()),
                 num_tokens,
                 num_experts,
@@ -56,7 +57,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
 std::vector<paddle::DataType> NoauxTcInferDtype(
     const paddle::DataType& scores_dtype,
     const paddle::DataType& scores_with_bias_dtype) {
-  return {scores_dtype, scores_dtype, paddle::DataType::INT32};
+  return {scores_dtype, scores_dtype, paddle::DataType::INT64};
 }
 
 std::vector<std::vector<int64_t>> NoauxTcInferShape(
@@ -71,7 +72,7 @@ std::vector<std::vector<int64_t>> NoauxTcInferShape(
 PD_BUILD_STATIC_OP(noaux_tc)
     .Inputs({"scores", "scores_with_bias"})
-    .Outputs({"output_tensor"})
+    .Outputs({"output_tensor", "topk_values", "topk_indices"})
     .Attrs({"n_group: int",
             "topk_group: int",
             "topk:int",
diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index cb717f963..02ccead7f 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -49,7 +49,7 @@ def get_moe_scores(
     compute moe scores using e_score_correction_bias.
     """
     scores = paddle.nn.functional.sigmoid(gating_output)
-    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
+    scores_with_bias = scores + e_score_correction_bias
     scores, topk_values, topk_idx = noaux_tc(
         scores,
         scores_with_bias,
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index 73306680b..6a86589bf 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -312,13 +312,26 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         below is TP compute method.
         """
         gate_out = gate(x.cast("float32"))
-        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
-            gate_out,
-            layer.gate_correction_bias,
-            layer.top_k,
-            True,  # apply_norm_weight
-            False,
-        )
+
+        if layer.topk_method == "noaux_tc":
+            from .ep import get_moe_scores
+
+            _, topk_weights, topk_ids = get_moe_scores(
+                gate_out,
+                layer.n_group,
+                layer.topk_group,
+                layer.top_k,
+                layer.routed_scaling_factor,
+                layer.gate_correction_bias,
+            )
+        else:
+            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                gate_out,
+                layer.gate_correction_bias,
+                layer.top_k,
+                True,  # apply_norm_weight
+                False,
+            )
 
         tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 310f4d3df..c46bbac72 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -285,7 +285,7 @@ class FusedMoE(nn.Layer):
             dtype="float32",
         )
         up_gate_proj_output_dim = self.moe_intermediate_size * 2
-        if self.moe_quant_type in ["fp8", "wint8"]:
+        if self.moe_quant_type in ["block_wise_fp8", "wint8"]:
             up_gate_proj_weight_shape = [
                 self.num_local_experts,
                 up_gate_proj_output_dim,
@@ -309,9 +309,10 @@ class FusedMoE(nn.Layer):
         ]
 
         # Create parameters
-        if self.moe_quant_type == "fp8":
+        if self.moe_quant_type == "block_wise_fp8":
             # (TODO:gaoziyuan)
-            pass
+            self.weight_dtype = "float8_e4m3fn"
+            self.init_block_wise_fp8_scale()
         elif self.moe_quant_type == "wint8":
             self.weight_dtype = "int8"
             self.init_weight_only_scale()
@@ -342,6 +343,21 @@ class FusedMoE(nn.Layer):
             dtype=self._dtype,
         )
 
+    def init_block_wise_fp8_scale(self):
+        """
+        Initialize the weight scale.
+        """
+        self.up_gate_proj_weight_scale = self.create_parameter(
+            shape=[self.num_local_experts, self.moe_intermediate_size * 2 // 128, self.hidden_size // 128],
+            dtype="float32",
+            is_bias=False,
+        )
+        self.down_proj_weight_scale = self.create_parameter(
+            shape=[self.num_local_experts, self.hidden_size // 128, self.moe_intermediate_size // 128],
+            dtype="float32",
+            is_bias=False,
+        )
+
     def load_experts_weight(
         self,
         state_dict: dict,