diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 875cccab8..e27861d2f 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -116,11 +116,11 @@ PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder, paddle::Tensor FusedExpertMoeFunc( const paddle::Tensor &input, const paddle::Tensor &gate_weight, - const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight, - const paddle::optional &ffn1_bias, - const paddle::optional &ffn1_scale, - const paddle::optional &ffn2_bias, - const paddle::optional &ffn2_scale, + const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight, + const paddle::optional &up_gate_proj_bias, + const paddle::optional &up_gate_proj_scale, + const paddle::optional &down_proj_bias, + const paddle::optional &down_proj_scale, const std::string &quant_method, const int moe_topk, const bool norm_topk_prob, const bool group_moe); @@ -149,7 +149,7 @@ MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits, std::vector EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights, - const paddle::optional &ffn1_in_scale, + const paddle::optional &up_gate_proj_in_scale, const std::vector &token_nums_per_expert, const int token_nums_this_rank, const std::string &moe_quant_type); @@ -173,7 +173,7 @@ std::vector EPMoeExpertCombine( const paddle::Tensor &ffn_out, const paddle::Tensor &expert_scales_float, const paddle::Tensor &permute_indices_per_token, const paddle::Tensor &top_k_indices, - const paddle::optional &ffn2_bias, + const paddle::optional &down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor); std::vector> GetExpertTokenNum(const paddle::Tensor &topk_ids, @@ -182,35 +182,35 @@ std::vector> GetExpertTokenNum(const paddle::Tensor &topk_ids, paddle::Tensor MoeExpertFFNFunc( const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_scale, - const paddle::optional& ffn2_in_scale, + const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_scale, + const paddle::optional& down_proj_in_scale, const paddle::optional& expert_idx_per_token, const std::string& quant_method, const bool used_in_ep_low_latency); paddle::Tensor MoeExpertFFNWint2Func( const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_scale, - const paddle::optional& ffn1_local_scale, - const paddle::optional& ffn1_code_scale, - const paddle::optional& ffn1_code_zp, - const paddle::optional& ffn2_local_scale, - const paddle::optional& ffn2_code_scale, - const paddle::optional& ffn2_code_zp, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_scale, + const paddle::optional& up_gate_proj_local_scale, + const paddle::optional& up_gate_proj_code_scale, + const paddle::optional& up_gate_proj_code_zp, + 
const paddle::optional& down_proj_local_scale, + const paddle::optional& down_proj_code_scale, + const paddle::optional& down_proj_code_zp, const bool used_in_ep_low_latency); paddle::Tensor MoeExpertReduceFunc( const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight, const paddle::Tensor &permute_indices_per_token, const paddle::Tensor &top_k_indices, - const paddle::optional &ffn2_bias, + const paddle::optional &down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor); void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor, @@ -816,7 +816,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * ep_moe_dispatch */ m.def("ep_moe_expert_dispatch", &EPMoeExpertDispatch, py::arg("input"), - py::arg("topk_ids"), py::arg("topk_weights"), py::arg("ffn1_in_scale"), + py::arg("topk_ids"), py::arg("topk_weights"), py::arg("up_gate_proj_in_scale"), py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"), py::arg("moe_quant_type"), "ep moe export dispatch function"); @@ -824,7 +824,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"), py::arg("expert_scales_float"), py::arg("permute_indices_per_token"), - py::arg("top_k_indices"), py::arg("ffn2_bias"), + py::arg("top_k_indices"), py::arg("down_proj_bias"), py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"), "ep moe export combine function"); @@ -866,7 +866,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { */ m.def("moe_expert_reduce", &MoeExpertReduceFunc, py::arg("ffn_out"), py::arg("top_k_weight"), py::arg("permute_indices_per_token"), - py::arg("top_k_indices"), py::arg("ffn2_bias"), + py::arg("top_k_indices"), py::arg("down_proj_bias"), py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"), "moe export reduce function"); diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh index bfd08d5e6..eb3be5fa5 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-ffn1_n=7168 -ffn1_k=8192 +up_gate_proj_n=7168 +up_gate_proj_k=8192 -ffn2_n=8192 -ffn2_k=3584 -rm -rf ffn1_7168_8192.log -rm -rf ffn2_8192_3584.log +down_proj_n=8192 +down_proj_k=3584 +rm -rf up_gate_proj_7168_8192.log +rm -rf down_proj_8192_3584.log num_experts=8 for tokens_per_expert in 12 do wait -CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 1 0 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 & -# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 1 0 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 & +CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 & +# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 & done wait echo "#### finish ####" diff --git a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu index 09e006cdc..60ae7d1fc 100644 --- a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu +++ b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu @@ -161,7 +161,7 @@ __global__ void combine_prmt_back_kernel( expanded_permuted_rows + expanded_permuted_row * cols; // prmt后的位置对应的值 Load(expanded_permuted_rows_row_ptr + tid * VEC_SIZE, &load_vec); const int expert_idx = expert_for_source_row[k_offset]; // 当前位置对应的专家 - const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的ffn2的bias + const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的down_proj的bias if (bias_ptr) { Load(bias_ptr + tid * VEC_SIZE, &bias_vec); #pragma unroll @@ -188,7 +188,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out, const paddle::Tensor& expert_scales_float, const paddle::Tensor& permute_indices_per_token, const paddle::Tensor& top_k_indices, - const paddle::optional& ffn2_bias, + const paddle::optional& down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor, const int num_rows, @@ -206,7 +206,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out, combine_prmt_back_kernel<<>>( ffn_out.data(), output->data(), - ffn2_bias ? ffn2_bias->data() : nullptr, + down_proj_bias ? 
down_proj_bias->data() : nullptr, expert_scales_float.data(), permute_indices_per_token.data(), top_k_indices.data(), @@ -223,7 +223,7 @@ std::vector EPMoeExpertCombine( const paddle::Tensor& expert_scales_float, // dst_weights const paddle::Tensor& permute_indices_per_token, // permute_indices_per_token const paddle::Tensor& top_k_indices, // dst_indices - const paddle::optional& ffn2_bias, + const paddle::optional& down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor) { @@ -242,7 +242,7 @@ std::vector EPMoeExpertCombine( expert_scales_float, permute_indices_per_token, top_k_indices, - ffn2_bias, + down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, @@ -255,7 +255,7 @@ std::vector EPMoeExpertCombine( expert_scales_float, permute_indices_per_token, top_k_indices, - ffn2_bias, + down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, @@ -274,7 +274,7 @@ __global__ void permute_x_kernel(const T *src_x, const int64_t *topk_idx, const float *topk_weights, const int *token_nums_per_expert, - const float *ffn1_in_scale, + const float *up_gate_proj_in_scale, const int moe_topk, const int num_rows, const int token_nums_this_rank, @@ -327,9 +327,9 @@ __global__ void permute_x_kernel(const T *src_x, // cp x for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) { Load(&src_x[s_token_idx * hidden_size + v_id * vec_size], &src_vec); - if (ffn1_in_scale) { + if (up_gate_proj_in_scale) { for (int i = 0; i < vec_size; i++) { - float quant_value = max_bound * ffn1_in_scale[expert_now] * static_cast(src_vec[i]); + float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast(src_vec[i]); if (RoundType == 0) { res_vec[i] = static_cast(ClipFunc(rint(quant_value), min_bound, max_bound)); } else { @@ -353,7 +353,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input, const paddle::Tensor& topk_ids, const paddle::Tensor& topk_weights, const paddle::Tensor& token_nums_per_expert, - const paddle::optional& ffn1_in_scale, + const paddle::optional& up_gate_proj_in_scale, const std::string& moe_quant_type, const int moe_topk, const int num_rows, @@ -383,7 +383,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input, topk_ids.data(), topk_weights.data(), token_nums_per_expert.data(), - ffn1_in_scale ? ffn1_in_scale.get().data() : nullptr, + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, moe_topk, num_rows, token_nums_this_rank, @@ -404,7 +404,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input, topk_ids.data(), topk_weights.data(), token_nums_per_expert.data(), - ffn1_in_scale ? ffn1_in_scale.get().data() : nullptr, + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, moe_topk, num_rows, token_nums_this_rank, @@ -427,7 +427,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input, topk_ids.data(), topk_weights.data(), token_nums_per_expert.data(), - ffn1_in_scale ? ffn1_in_scale.get().data() : nullptr, + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, moe_topk, num_rows, token_nums_this_rank, @@ -448,7 +448,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input, topk_ids.data(), topk_weights.data(), token_nums_per_expert.data(), - ffn1_in_scale ? ffn1_in_scale.get().data() : nullptr, + up_gate_proj_in_scale ? 
up_gate_proj_in_scale.get().data() : nullptr, moe_topk, num_rows, token_nums_this_rank, @@ -472,7 +472,7 @@ std::vector EPMoeExpertDispatch( const paddle::Tensor& input, const paddle::Tensor& topk_ids, const paddle::Tensor& topk_weights, - const paddle::optional& ffn1_in_scale, + const paddle::optional& up_gate_proj_in_scale, const std::vector& token_nums_per_expert, const int token_nums_this_rank, const std::string& moe_quant_type) { @@ -516,7 +516,7 @@ std::vector EPMoeExpertDispatch( topk_ids, topk_weights, num_experts_per_rank_tensor, - ffn1_in_scale, + up_gate_proj_in_scale, moe_quant_type, moe_topk, num_rows, @@ -536,7 +536,7 @@ std::vector EPMoeExpertDispatch( topk_ids, topk_weights, num_experts_per_rank_tensor, - ffn1_in_scale, + up_gate_proj_in_scale, moe_quant_type, moe_topk, num_rows, @@ -568,7 +568,7 @@ std::vector> EPMoeExpertDispatchInferShape( const std::vector& input_shape, const std::vector& topk_ids_shape, const std::vector& topk_weights_shape, - const paddle::optional>& ffn1_in_scale_dtype, + const paddle::optional>& up_gate_proj_in_scale_dtype, const std::vector& token_nums_per_expert, const int token_nums_this_rank) { int token_rows = -1; @@ -610,7 +610,7 @@ std::vector EPMoeExpertDispatchInferDtype( PD_BUILD_STATIC_OP(ep_moe_expert_dispatch) .Inputs({"input", "topk_ids", "topk_weights", - paddle::Optional("ffn1_in_scale")}) + paddle::Optional("up_gate_proj_in_scale")}) .Outputs({"permute_input", "permute_indices_per_token", "token_nums_per_expert_cumsum", diff --git a/custom_ops/gpu_ops/moe/fused_moe.cu b/custom_ops/gpu_ops/moe/fused_moe.cu index 0b4104860..a09bfa9e7 100644 --- a/custom_ops/gpu_ops/moe/fused_moe.cu +++ b/custom_ops/gpu_ops/moe/fused_moe.cu @@ -54,12 +54,12 @@ void compute_total_rows_before_expert(int* sorted_indices, template void FusedMoeKernel(const paddle::Tensor& input, const paddle::Tensor& gate_weight, - const paddle::Tensor& ffn1_weight, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn1_bias, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn2_scale, - const paddle::optional& ffn2_bias, + const paddle::Tensor& up_gate_proj_weight, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& up_gate_proj_bias, + const paddle::Tensor& down_proj_weight, + const paddle::optional& down_proj_scale, + const paddle::optional& down_proj_bias, const std::string& quant_method, const int moe_topk, const bool group_moe, @@ -84,12 +84,12 @@ void FusedMoeKernel(const paddle::Tensor& input, moe_compute.ComputeFFN(&input, &gate_weight, - &ffn1_weight, - ffn1_scale ? ffn1_scale.get_ptr() : nullptr, - ffn1_bias ? ffn1_bias.get_ptr() : nullptr, - &ffn2_weight, - ffn2_scale ? ffn2_scale.get_ptr() : nullptr, - ffn2_bias ? ffn2_bias.get_ptr() : nullptr, + &up_gate_proj_weight, + up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr, + up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr, + &down_proj_weight, + down_proj_scale ? down_proj_scale.get_ptr() : nullptr, + down_proj_bias ? 
down_proj_bias.get_ptr() : nullptr, nullptr, moe_topk, group_moe, @@ -102,12 +102,12 @@ void FusedMoeKernel(const paddle::Tensor& input, paddle::Tensor FusedExpertMoeFunc( const paddle::Tensor& input, const paddle::Tensor& gate_weight, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_bias, - const paddle::optional& ffn2_scale, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_bias, + const paddle::optional& down_proj_scale, const std::string& quant_method, const int moe_topk, const bool norm_topk_prob, @@ -119,12 +119,12 @@ paddle::Tensor FusedExpertMoeFunc( case paddle::DataType::BFLOAT16: FusedMoeKernel(input, gate_weight, - ffn1_weight, - ffn1_scale, - ffn1_bias, - ffn2_weight, - ffn2_scale, - ffn2_bias, + up_gate_proj_weight, + up_gate_proj_scale, + up_gate_proj_bias, + down_proj_weight, + down_proj_scale, + down_proj_bias, quant_method, moe_topk, group_moe, @@ -134,12 +134,12 @@ paddle::Tensor FusedExpertMoeFunc( case paddle::DataType::FLOAT16: FusedMoeKernel(input, gate_weight, - ffn1_weight, - ffn1_scale, - ffn1_bias, - ffn2_weight, - ffn2_scale, - ffn2_bias, + up_gate_proj_weight, + up_gate_proj_scale, + up_gate_proj_bias, + down_proj_weight, + down_proj_scale, + down_proj_bias, quant_method, moe_topk, group_moe, @@ -155,24 +155,24 @@ paddle::Tensor FusedExpertMoeFunc( std::vector FusedExpertMoe( const paddle::Tensor& input, const paddle::Tensor& gate_weight, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_bias, - const paddle::optional& ffn2_scale, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_bias, + const paddle::optional& down_proj_scale, const std::string& quant_method, const int moe_topk, const bool norm_topk_prob, const bool group_moe) { return {FusedExpertMoeFunc(input, gate_weight, - ffn1_weight, - ffn2_weight, - ffn1_bias, - ffn1_scale, - ffn2_bias, - ffn2_scale, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_bias, + down_proj_scale, quant_method, moe_topk, norm_topk_prob, @@ -182,30 +182,30 @@ std::vector FusedExpertMoe( std::vector> FusedExpertMoeInferShape( const std::vector& input_shape, const std::vector& gate_weight_shape, - const std::vector& ffn1_weight_shape, - const std::vector& ffn2_weight_shape, - const paddle::optional>& ffn1_bias_shape, - const paddle::optional>& ffn1_scale_shape, - const paddle::optional>& ffn2_bias_shape, - const paddle::optional>& ffn2_scale_shape) { + const std::vector& up_gate_proj_weight_shape, + const std::vector& down_proj_weight_shape, + const paddle::optional>& up_gate_proj_bias_shape, + const paddle::optional>& up_gate_proj_scale_shape, + const paddle::optional>& down_proj_bias_shape, + const paddle::optional>& down_proj_scale_shape) { return {input_shape}; } std::vector FusedExpertMoeInferDtype( const paddle::DataType& input_dtype, const paddle::DataType& gate_weight_dtype, - const paddle::DataType& ffn1_weight_dtype, - const paddle::DataType& ffn2_weight_dtype, - const paddle::optional& 
ffn1_bias_dtype, - const paddle::optional& ffn1_scale_dtype, - const paddle::optional& ffn2_bias_dtype, - const paddle::optional& ffn2_scale_dtype) { + const paddle::DataType& up_gate_proj_weight_dtype, + const paddle::DataType& down_proj_weight_dtype, + const paddle::optional& up_gate_proj_bias_dtype, + const paddle::optional& up_gate_proj_scale_dtype, + const paddle::optional& down_proj_bias_dtype, + const paddle::optional& down_proj_scale_dtype) { return {input_dtype}; } /** * @brief Fused Mixture-of-Experts (MoE) Operator - * + * * This operator combines three key MoE operations into a single optimized kernel: * 1. moe_dispatch - Routes tokens to top-k experts using gating network * 2. moe_ffn - Processes tokens through parallel expert FFNs @@ -230,12 +230,12 @@ std::vector FusedExpertMoeInferDtype( PD_BUILD_STATIC_OP(fused_expert_moe) .Inputs({"input", "gate_weight", - "ffn1_weight", - "ffn2_weight", - paddle::Optional("ffn1_bias"), - paddle::Optional("ffn1_scale"), - paddle::Optional("ffn2_bias"), - paddle::Optional("ffn2_scale")}) + "up_gate_proj_weight", + "down_proj_weight", + paddle::Optional("up_gate_proj_bias"), + paddle::Optional("up_gate_proj_scale"), + paddle::Optional("down_proj_bias"), + paddle::Optional("down_proj_scale")}) .Outputs({"output"}) .Attrs({"quant_method:std::string", "moe_topk:int", diff --git a/custom_ops/gpu_ops/moe/fused_moe_helper.h b/custom_ops/gpu_ops/moe/fused_moe_helper.h index 6af1ab41a..22bf0f1f9 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_helper.h +++ b/custom_ops/gpu_ops/moe/fused_moe_helper.h @@ -117,18 +117,18 @@ public: void ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight, - const paddle::Tensor *ffn1_weight, - const paddle::Tensor *ffn1_scale, const paddle::Tensor *ffn1_bias, - const paddle::Tensor *ffn2_weight, - const paddle::Tensor *ffn2_scale, const paddle::Tensor *ffn2_bias, + const paddle::Tensor *up_gate_proj_weight, + const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias, + const paddle::Tensor *down_proj_weight, + const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias, const paddle::Tensor *moe_token_type_ids, const int moe_topk, const bool group_moe, const bool norm_topk_prob, const float routed_scaling_factor, const std::string moe_type, paddle::Tensor *output) { auto *input_activations = input->data(); auto *gating_weights = gate_weight->data(); - const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data() : nullptr; - const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data() : nullptr; + const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data() : nullptr; + const T *fc2_expert_biases = down_proj_bias ? 
down_proj_bias->data() : nullptr; auto *output_ = output->data(); auto stream = input->stream(); @@ -136,7 +136,7 @@ public: auto input_type = input->dtype(); auto input_dims = input->dims(); - auto ffn1_dims = ffn1_weight->dims(); + auto up_gate_proj_dims = up_gate_proj_weight->dims(); int64_t token_num = 0; if (input_dims.size() == 3) { token_num = input_dims[0] * input_dims[1]; @@ -145,12 +145,12 @@ public: } const int64_t num_rows = token_num; - const int64_t hidden_size = ffn1_dims[1]; + const int64_t hidden_size = up_gate_proj_dims[1]; int64_t inter_dim = 0; if (moe_type == "qkv") { - inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4]; + inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4]; } else { - inter_dim = ffn1_dims[2]; + inter_dim = up_gate_proj_dims[2]; } if (gemm_method_ == "weight_only_int4") { @@ -158,7 +158,7 @@ public: } const int64_t inter_size = inter_dim; - const int64_t num_experts = ffn1_dims[0]; + const int64_t num_experts = up_gate_proj_dims[0]; const int64_t k = moe_topk; int64_t bytes = @@ -260,38 +260,38 @@ public: total_rows_before_expert_, stream); if (gemm_method_ == "weight_only_int8") { - typename Int8Traits::Arguments ffn1_quant_args; + typename Int8Traits::Arguments up_gate_proj_quant_args; int8_moe_gemm_runner_->moe_gemm_bias_act( reinterpret_cast(permuted_data_), - reinterpret_cast(ffn1_weight->data()), - reinterpret_cast(ffn1_scale->data()), + reinterpret_cast(up_gate_proj_weight->data()), + reinterpret_cast(up_gate_proj_scale->data()), reinterpret_cast(fc1_expert_biases), reinterpret_cast(fc1_out), total_rows_before_expert_, -1, // useless expanded_active_expert_rows, inter_size, hidden_size, num_experts, - ffn1_quant_args, "none", stream); + up_gate_proj_quant_args, "none", stream); } else if (gemm_method_ == "weight_only_int4") { - typename Int4Traits::Arguments ffn1_quant_args; + typename Int4Traits::Arguments up_gate_proj_quant_args; int4_moe_gemm_runner_->moe_gemm_bias_act( reinterpret_cast(permuted_data_), reinterpret_cast( - ffn1_weight->data()), - reinterpret_cast(ffn1_scale->data()), + up_gate_proj_weight->data()), + reinterpret_cast(up_gate_proj_scale->data()), reinterpret_cast(fc1_expert_biases), reinterpret_cast(fc1_out), total_rows_before_expert_, -1, // useless expanded_active_expert_rows, inter_size, hidden_size, num_experts, - ffn1_quant_args, "none", stream); + up_gate_proj_quant_args, "none", stream); } else { - typename Fp16Traits::Arguments ffn1_quant_args; + typename Fp16Traits::Arguments up_gate_proj_quant_args; fp16_moe_gemm_runner_->moe_gemm_bias_act( reinterpret_cast(permuted_data_), - reinterpret_cast(ffn1_weight->data()), nullptr, + reinterpret_cast(up_gate_proj_weight->data()), nullptr, reinterpret_cast(fc1_expert_biases), reinterpret_cast(fc1_out), total_rows_before_expert_, -1, // useless expanded_active_expert_rows, inter_size, hidden_size, num_experts, - ffn1_quant_args, "none", stream); + up_gate_proj_quant_args, "none", stream); } if (moe_type == "ffn") { @@ -304,35 +304,35 @@ public: T *fc2_result = fc2_output_tensor.data(); if (gemm_method_ == "weight_only_int8") { - typename Int8Traits::Arguments ffn2_quant_args; + typename Int8Traits::Arguments down_proj_quant_args; int8_moe_gemm_runner_->moe_gemm( reinterpret_cast(act_out), - reinterpret_cast(ffn2_weight->data()), - reinterpret_cast(ffn2_scale->data()), + reinterpret_cast(down_proj_weight->data()), + reinterpret_cast(down_proj_scale->data()), reinterpret_cast(fc2_result), total_rows_before_expert_, -1, // useless 
expanded_active_expert_rows, hidden_size, inter_size / 2, - num_experts, ffn2_quant_args, stream); + num_experts, down_proj_quant_args, stream); } else if (gemm_method_ == "weight_only_int4") { - typename Int4Traits::Arguments ffn2_quant_args; + typename Int4Traits::Arguments down_proj_quant_args; int4_moe_gemm_runner_->moe_gemm( reinterpret_cast(act_out), reinterpret_cast( - ffn2_weight->data()), - reinterpret_cast(ffn2_scale->data()), + down_proj_weight->data()), + reinterpret_cast(down_proj_scale->data()), reinterpret_cast(fc2_result), total_rows_before_expert_, -1, // useless expanded_active_expert_rows, hidden_size, inter_size / 2, - num_experts, ffn2_quant_args, stream); + num_experts, down_proj_quant_args, stream); } else { - typename Fp16Traits::Arguments ffn2_quant_args; + typename Fp16Traits::Arguments down_proj_quant_args; fp16_moe_gemm_runner_->moe_gemm( reinterpret_cast(act_out), - reinterpret_cast(ffn2_weight->data()), nullptr, + reinterpret_cast(down_proj_weight->data()), nullptr, reinterpret_cast(fc2_result), total_rows_before_expert_, -1, // useless expanded_active_expert_rows, hidden_size, inter_size / 2, - num_experts, ffn2_quant_args, stream); + num_experts, down_proj_quant_args, stream); } finalize_moe_routing_kernelLauncher::run( diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu b/custom_ops/gpu_ops/moe/moe_ffn.cu index dfb66640d..1d453466d 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn.cu @@ -24,12 +24,12 @@ template void MoeFFNKernel(const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_scale, - const paddle::optional& ffn2_in_scale, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_scale, + const paddle::optional& down_proj_in_scale, const paddle::optional& expert_idx_per_token, const std::string& quant_method, paddle::Tensor ffn_out, @@ -51,11 +51,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2); - const int num_experts = ffn1_weight.dims()[0]; + const int num_experts = up_gate_proj_weight.dims()[0]; const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1]; - assert(ffn1_weight.dims().size() == 3); - int inter_dim = ffn1_weight.dims()[1] * ffn1_weight.dims()[2] / hidden_size; + assert(up_gate_proj_weight.dims().size() == 3); + int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size; constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k Allocator* allocator = paddle::GetAllocator(place); @@ -96,8 +96,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, using NvType = typename traits_::DataType; auto fc1_expert_biases = - ffn1_bias - ? const_cast(ffn1_bias.get_ptr())->data() + up_gate_proj_bias + ? const_cast(up_gate_proj_bias.get_ptr())->data() : nullptr; // This is a trick. 
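Reviewer note: the hunks that follow thread the renamed tensors through the same three-step expert pipeline the old ffn1/ffn2 code implemented — a grouped GEMM with up_gate_proj_weight, SwiGLU, then a grouped GEMM with down_proj_weight. As a reading aid, here is a minimal single-expert CPU sketch of that computation; the unquantized row-major layouts, the [up | gate] packing, and the helper names are illustrative assumptions, not the CUTLASS grouped-GEMM path this file actually dispatches to.

```cpp
// Reference (CPU) sketch of one expert's FFN using the renamed tensors.
// Assumed layouts for illustration: up_gate_proj [hidden, 2*inter] packed as [up | gate],
// down_proj [inter, hidden]; the real kernels use per-expert, possibly quantized layouts.
#include <cmath>
#include <cstdio>
#include <vector>

static float silu(float x) { return x / (1.0f + std::exp(-x)); }

std::vector<float> expert_ffn(const std::vector<float>& x,             // [hidden]
                              const std::vector<float>& up_gate_proj,  // [hidden, 2*inter]
                              const std::vector<float>& down_proj,     // [inter, hidden]
                              int hidden, int inter) {
  std::vector<float> fc1(2 * inter, 0.0f);   // up_gate_proj GEMM
  for (int j = 0; j < 2 * inter; ++j)
    for (int i = 0; i < hidden; ++i)
      fc1[j] += x[i] * up_gate_proj[i * 2 * inter + j];

  std::vector<float> act(inter);             // SwiGLU: silu(gate) * up
  for (int j = 0; j < inter; ++j)
    act[j] = silu(fc1[inter + j]) * fc1[j];  // swap the halves if the kernel packs [gate | up]

  std::vector<float> y(hidden, 0.0f);        // down_proj GEMM
  for (int i = 0; i < hidden; ++i)
    for (int j = 0; j < inter; ++j)
      y[i] += act[j] * down_proj[j * hidden + i];
  return y;
}

int main() {
  const int hidden = 4, inter = 8;
  std::vector<float> x(hidden, 0.5f), w1(hidden * 2 * inter, 0.01f), w2(inter * hidden, 0.02f);
  auto y = expert_ffn(x, w1, w2, hidden, inter);
  std::printf("y[0] = %f\n", y[0]);
  return 0;
}
```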
@@ -112,9 +112,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, typename cutlass::WintQuantTraits::Arguments quant_args; int8_moe_gemm_runner.moe_gemm_bias_act( reinterpret_cast(permute_input.data()), - reinterpret_cast(ffn1_weight.data()), + reinterpret_cast(up_gate_proj_weight.data()), reinterpret_cast( - const_cast(ffn1_scale.get_ptr()) + const_cast(up_gate_proj_scale.get_ptr()) ->data()), reinterpret_cast(fc1_expert_biases), reinterpret_cast(fc1_out), @@ -132,9 +132,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, int4_moe_gemm_runner.moe_gemm_bias_act( reinterpret_cast(permute_input.data()), reinterpret_cast( - ffn1_weight.data()), + up_gate_proj_weight.data()), reinterpret_cast( - const_cast(ffn1_scale.get_ptr()) + const_cast(up_gate_proj_scale.get_ptr()) ->data()), reinterpret_cast(fc1_expert_biases), reinterpret_cast(fc1_out), @@ -151,12 +151,12 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, w4a8_moe_gemm_runner.moe_gemm( reinterpret_cast(permute_input.data()), reinterpret_cast( - ffn1_weight.data()), + up_gate_proj_weight.data()), quant_mode, reinterpret_cast( - const_cast(ffn1_scale.get_ptr()) + const_cast(up_gate_proj_scale.get_ptr()) ->data()), - nullptr, // ffn1_scale_dyquant + nullptr, // up_gate_proj_scale_dyquant nullptr, // nf4_look_up_table reinterpret_cast(fc1_out), const_cast(tokens_expert_prefix_sum.data()), @@ -172,7 +172,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, typename cutlass::WintQuantTraits::Arguments quant_args; fp16_moe_gemm_runner.moe_gemm_bias_act( reinterpret_cast(permute_input.data()), - reinterpret_cast(ffn1_weight.data()), + reinterpret_cast(up_gate_proj_weight.data()), nullptr, reinterpret_cast(fc1_expert_biases), reinterpret_cast(fc1_out), @@ -199,9 +199,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, typename cutlass::WintQuantTraits::Arguments quant_args; int8_moe_gemm_runner.moe_gemm( reinterpret_cast(act_out), - reinterpret_cast(ffn2_weight.data()), + reinterpret_cast(down_proj_weight.data()), reinterpret_cast( - const_cast(ffn2_scale.get_ptr()) + const_cast(down_proj_scale.get_ptr()) ->data()), reinterpret_cast(ffn_out_data), const_cast(tokens_expert_prefix_sum.data()), @@ -218,9 +218,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, int4_moe_gemm_runner.moe_gemm( reinterpret_cast(act_out), reinterpret_cast( - ffn2_weight.data()), + down_proj_weight.data()), reinterpret_cast( - const_cast(ffn2_scale.get_ptr()) + const_cast(down_proj_scale.get_ptr()) ->data()), reinterpret_cast(ffn_out_data), const_cast(tokens_expert_prefix_sum.data()), @@ -232,17 +232,17 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, quant_args, stream); } else if (quant_method == "w4a8") { - data_t *ffn2_shift = nullptr; - data_t *ffn2_smooth = nullptr; + data_t *down_proj_shift = nullptr; + data_t *down_proj_smooth = nullptr; Allocator::AllocationPtr int8_act_out; int8_act_out = allocator->Allocate( SizeOf(paddle::DataType::INT8) * act_out_tensor.numel()); MoeFastHardamardWrapper( act_out_tensor.data(), expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, - ffn2_shift, // ffn2_shift->data(), - ffn2_smooth, // ffn2_smooth->data(), - ffn2_in_scale ? const_cast(ffn2_in_scale.get_ptr())->data() : nullptr, + down_proj_shift, // down_proj_shift->data(), + down_proj_smooth, // down_proj_smooth->data(), + down_proj_in_scale ? 
const_cast(down_proj_in_scale.get_ptr())->data() : nullptr, 1, 127.0, -127.0, @@ -254,12 +254,12 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, w4a8_moe_gemm_runner.moe_gemm( reinterpret_cast(int8_act_out->ptr()), reinterpret_cast( - ffn2_weight.data()), + down_proj_weight.data()), quant_mode, reinterpret_cast( - const_cast(ffn2_scale.get_ptr()) + const_cast(down_proj_scale.get_ptr()) ->data()), - nullptr, // ffn2_scale_dyquant + nullptr, // down_proj_scale_dyquant nullptr, // reinterpret_cast(d_nf4_look_up_table), // nf4_look_up_table reinterpret_cast(ffn_out_data), const_cast(tokens_expert_prefix_sum.data()), @@ -275,7 +275,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, typename cutlass::WintQuantTraits::Arguments quant_args; fp16_moe_gemm_runner.moe_gemm( reinterpret_cast(act_out), - reinterpret_cast(ffn2_weight.data()), + reinterpret_cast(down_proj_weight.data()), nullptr, reinterpret_cast(ffn_out_data), const_cast(tokens_expert_prefix_sum.data()), @@ -292,29 +292,29 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, paddle::Tensor MoeExpertFFNFunc( const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_scale, - const paddle::optional& ffn2_in_scale, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_scale, + const paddle::optional& down_proj_in_scale, const paddle::optional& expert_idx_per_token, const std::string& quant_method, const bool used_in_ep_low_latency) { cudaCheckError(); - const auto t_type = quant_method == "w4a8" ? ffn1_scale.get().dtype() : permute_input.dtype(); + const auto t_type = quant_method == "w4a8" ? 
up_gate_proj_scale.get().dtype() : permute_input.dtype(); auto ffn_out = paddle::empty_like(permute_input, t_type); switch (t_type) { case paddle::DataType::BFLOAT16: MoeFFNKernel(permute_input, tokens_expert_prefix_sum, - ffn1_weight, - ffn2_weight, - ffn1_bias, - ffn1_scale, - ffn2_scale, - ffn2_in_scale, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + down_proj_in_scale, expert_idx_per_token, quant_method, ffn_out, used_in_ep_low_latency); @@ -322,12 +322,12 @@ paddle::Tensor MoeExpertFFNFunc( case paddle::DataType::FLOAT16: MoeFFNKernel(permute_input, tokens_expert_prefix_sum, - ffn1_weight, - ffn2_weight, - ffn1_bias, - ffn1_scale, - ffn2_scale, - ffn2_in_scale, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + down_proj_in_scale, expert_idx_per_token, quant_method, ffn_out, used_in_ep_low_latency); @@ -341,22 +341,22 @@ paddle::Tensor MoeExpertFFNFunc( std::vector MoeExpertFFN( const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_scale, - const paddle::optional& ffn2_in_scale, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_scale, + const paddle::optional& down_proj_in_scale, const paddle::optional& expert_idx_per_token, const std::string& quant_method, const bool used_in_ep_low_latency) { return {MoeExpertFFNFunc(permute_input, tokens_expert_prefix_sum, - ffn1_weight, - ffn2_weight, - ffn1_bias, - ffn1_scale, - ffn2_scale, - ffn2_in_scale, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + down_proj_in_scale, expert_idx_per_token, quant_method, used_in_ep_low_latency)}; } @@ -364,12 +364,12 @@ std::vector MoeExpertFFN( std::vector> MoeExpertFFNInferShape( const std::vector& permute_input_shape, const std::vector& tokens_expert_prefix_sum_shape, - const std::vector& ffn1_weight_shape, - const std::vector& ffn2_weight_shape, - const paddle::optional>& ffn1_bias_shape, - const paddle::optional>& ffn1_scale_shape, - const paddle::optional>& ffn2_scale_shape, - const paddle::optional>& ffn2_in_scale_shape, + const std::vector& up_gate_proj_weight_shape, + const std::vector& down_proj_weight_shape, + const paddle::optional>& up_gate_proj_bias_shape, + const paddle::optional>& up_gate_proj_scale_shape, + const paddle::optional>& down_proj_scale_shape, + const paddle::optional>& down_proj_in_scale_shape, const paddle::optional>& expert_idx_per_token_shape, const std::string& quant_method, const bool used_in_ep_low_latency) { @@ -379,15 +379,15 @@ std::vector> MoeExpertFFNInferShape( std::vector MoeExpertFFNInferDtype( const paddle::DataType &permute_input_dtype, const paddle::DataType &tokens_expert_prefix_sum_dtype, - const paddle::DataType &ffn1_weight_dtype, - const paddle::DataType &ffn2_weight_dtype, - const paddle::optional &ffn1_bias_dtype, - const paddle::optional &ffn1_scale_dtype, - const paddle::optional &ffn2_scale_dtype, - const paddle::optional &ffn2_in_scale_dtype, + const paddle::DataType &up_gate_proj_weight_dtype, + const paddle::DataType &down_proj_weight_dtype, + const paddle::optional &up_gate_proj_bias_dtype, + const 
paddle::optional &up_gate_proj_scale_dtype, + const paddle::optional &down_proj_scale_dtype, + const paddle::optional &down_proj_in_scale_dtype, const std::string &quant_method, const bool used_in_ep_low_latency) { if (quant_method == "w4a8") { - return {ffn1_scale_dtype.get()}; + return {up_gate_proj_scale_dtype.get()}; } else { return {permute_input_dtype}; } @@ -397,9 +397,9 @@ std::vector MoeExpertFFNInferDtype( * @brief Mixture of Experts (MoE) Feed-Forward Network Operator * * This operator performs the expert computation in MoE architecture, including: - * 1. First linear transformation (FFN1) with optional quantization + * 1. First linear transformation (up_gate_proj) with optional quantization * 2. SwiGLU activation function - * 3. Second linear transformation (FFN2) with optional quantization + * 3. Second linear transformation (down_proj) with optional quantization * * Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization. * @@ -410,22 +410,22 @@ std::vector MoeExpertFFNInferDtype( * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm * Shape: [num_experts] * dtype: int64 - * - ffn1_weight: First FFN layer weights + * - up_gate_proj_weight: First FFN layer weights * Shape: [num_experts, inter_size * 2, hidden_size] * dtype: Same as input (unquantized) or int8 (quantized) - * - ffn2_weight: Second FFN layer weights + * - down_proj_weight: Second FFN layer weights * Shape: [num_experts, hidden_size, inter_size] * dtype: Same as input (unquantized) or int8 (quantized) - * - ffn1_bias: Optional bias for first FFN layer + * - up_gate_proj_bias: Optional bias for first FFN layer * Shape: [num_experts, inter_size * 2] * dtype: Same as input - * - ffn1_scale: Quantization scales for first FFN layer + * - up_gate_proj_scale: Quantization scales for first FFN layer * Shape: [num_experts, inter_size * 2] * dtype: Same as input - * - ffn2_scale: Quantization scales for second FFN layer + * - down_proj_scale: Quantization scales for second FFN layer * Shape: [num_experts, hidden_size] * dtype: Same as input - * - ffn2_in_scale: Optional input scales for second FFN layer (w4a8 only) + * - down_proj_in_scale: Optional input scales for second FFN layer (w4a8 only) * dtype: float32 * - expert_idx_per_token: Optional expert indices per token (w4a8 only) * Shape: [total_tokens] @@ -434,7 +434,7 @@ std::vector MoeExpertFFNInferDtype( * Outputs: * - output_tensor: Output tensor after MoE FFN computation * Shape: Same as permute_input - * dtype: Same as input (or ffn1_scale dtype for w4a8) + * dtype: Same as input (or up_gate_proj_scale dtype for w4a8) * * Attributes: * - quant_method: Quantization method to use @@ -449,12 +449,12 @@ std::vector MoeExpertFFNInferDtype( PD_BUILD_STATIC_OP(moe_expert_ffn) .Inputs({"permute_input", "tokens_expert_prefix_sum", - "ffn1_weight", - "ffn2_weight", - paddle::Optional("ffn1_bias"), - paddle::Optional("ffn1_scale"), - paddle::Optional("ffn2_scale"), - paddle::Optional("ffn2_in_scale"), + "up_gate_proj_weight", + "down_proj_weight", + paddle::Optional("up_gate_proj_bias"), + paddle::Optional("up_gate_proj_scale"), + paddle::Optional("down_proj_scale"), + paddle::Optional("down_proj_in_scale"), paddle::Optional("expert_idx_per_token")}) .Outputs({"output_tensor"}) .Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool"}) diff --git a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu index fb9d2e69f..5a68c9e2f 100644 --- 
a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu @@ -23,17 +23,17 @@ template void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::Tensor* ffn1_bias, - const paddle::Tensor* ffn1_super_scale, - const paddle::Tensor* ffn2_super_scale, - const paddle::Tensor* ffn1_local_scale, - const paddle::Tensor* ffn1_code_scale, - const paddle::Tensor* ffn1_code_zp, - const paddle::Tensor* ffn2_local_scale, - const paddle::Tensor* ffn2_code_scale, - const paddle::Tensor* ffn2_code_zp, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::Tensor* up_gate_proj_bias, + const paddle::Tensor* up_gate_proj_super_scale, + const paddle::Tensor* down_proj_super_scale, + const paddle::Tensor* up_gate_proj_local_scale, + const paddle::Tensor* up_gate_proj_code_scale, + const paddle::Tensor* up_gate_proj_code_zp, + const paddle::Tensor* down_proj_local_scale, + const paddle::Tensor* down_proj_code_scale, + const paddle::Tensor* down_proj_code_zp, paddle::Tensor fc1_out, paddle::Tensor ffn_out, const int64_t total_rows_in_ll_else_minus1, @@ -46,15 +46,15 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, using WeightOnlyTraits = cutlass::WintQuantTraits; using WeightType = typename WeightOnlyTraits::WeightType; - typename WeightOnlyTraits::Arguments ffn1_quant_args; - typename WeightOnlyTraits::Arguments ffn2_quant_args; + typename WeightOnlyTraits::Arguments up_gate_proj_quant_args; + typename WeightOnlyTraits::Arguments down_proj_quant_args; if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) { - ffn1_quant_args.local_scale_ptr = ffn1_local_scale->data(); - ffn1_quant_args.code_scale_ptr = ffn1_code_scale->data(); - ffn1_quant_args.code_zp_ptr = ffn1_code_zp->data(); - ffn2_quant_args.local_scale_ptr = ffn2_local_scale->data(); - ffn2_quant_args.code_scale_ptr = ffn2_code_scale->data(); - ffn2_quant_args.code_zp_ptr = ffn2_code_zp->data(); + up_gate_proj_quant_args.local_scale_ptr = up_gate_proj_local_scale->data(); + up_gate_proj_quant_args.code_scale_ptr = up_gate_proj_code_scale->data(); + up_gate_proj_quant_args.code_zp_ptr = up_gate_proj_code_zp->data(); + down_proj_quant_args.local_scale_ptr = down_proj_local_scale->data(); + down_proj_quant_args.code_scale_ptr = down_proj_code_scale->data(); + down_proj_quant_args.code_zp_ptr = down_proj_code_zp->data(); } auto moe_gemm_runner = MoeGemmRunner(); @@ -62,9 +62,9 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, moe_gemm_runner.moe_gemm_bias_act( reinterpret_cast(permute_input.data()), - reinterpret_cast(ffn1_weight.data()), - reinterpret_cast(ffn1_super_scale ? ffn1_super_scale->data() : nullptr), - reinterpret_cast(ffn1_bias ? ffn1_bias->data() : nullptr), + reinterpret_cast(up_gate_proj_weight.data()), + reinterpret_cast(up_gate_proj_super_scale ? up_gate_proj_super_scale->data() : nullptr), + reinterpret_cast(up_gate_proj_bias ? 
up_gate_proj_bias->data() : nullptr), reinterpret_cast(fc1_out.data()), const_cast(tokens_expert_prefix_sum.data()), total_rows_in_ll_else_minus1, @@ -72,7 +72,7 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, inter_size, hidden_size, num_experts, - ffn1_quant_args, + up_gate_proj_quant_args, "none", stream); @@ -85,8 +85,8 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, moe_gemm_runner.moe_gemm( reinterpret_cast(act_out.data()), - reinterpret_cast(ffn2_weight.data()), - reinterpret_cast(ffn2_super_scale ? ffn2_super_scale->data() : nullptr), + reinterpret_cast(down_proj_weight.data()), + reinterpret_cast(down_proj_super_scale ? down_proj_super_scale->data() : nullptr), reinterpret_cast(ffn_out.data()), const_cast(tokens_expert_prefix_sum.data()), total_rows_in_ll_else_minus1, @@ -94,24 +94,24 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, hidden_size, inter_size / 2, num_experts, - ffn2_quant_args, + down_proj_quant_args, stream); } template void MoeFFNWint2Kernel(const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_scale, - const paddle::optional& ffn1_local_scale, - const paddle::optional& ffn1_code_scale, - const paddle::optional& ffn1_code_zp, - const paddle::optional& ffn2_local_scale, - const paddle::optional& ffn2_code_scale, - const paddle::optional& ffn2_code_zp, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_scale, + const paddle::optional& up_gate_proj_local_scale, + const paddle::optional& up_gate_proj_code_scale, + const paddle::optional& up_gate_proj_code_zp, + const paddle::optional& down_proj_local_scale, + const paddle::optional& down_proj_code_scale, + const paddle::optional& down_proj_code_zp, paddle::Tensor ffn_out, bool used_in_ep_low_latency) { using namespace phi; @@ -121,12 +121,12 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input, auto place = permute_input.place(); assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2); - assert(ffn1_weight.dims().size() == 3); + assert(up_gate_proj_weight.dims().size() == 3); - const int num_experts = ffn1_weight.dims()[0]; + const int num_experts = up_gate_proj_weight.dims()[0]; const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1]; - int inter_dim = ffn1_weight.dims()[1] * ffn1_weight.dims()[2] / hidden_size; + int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size; const int64_t inter_size = inter_dim * 4; @@ -160,17 +160,17 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input, WeightOnlyMoeFFNKernel( permute_input, tokens_expert_prefix_sum, - ffn1_weight, - ffn2_weight, - const_cast(ffn1_bias.get_ptr()), - const_cast(ffn1_scale.get_ptr()), - const_cast(ffn2_scale.get_ptr()), - const_cast(ffn1_local_scale.get_ptr()), - const_cast(ffn1_code_scale.get_ptr()), - const_cast(ffn1_code_zp.get_ptr()), - const_cast(ffn2_local_scale.get_ptr()), - const_cast(ffn2_code_scale.get_ptr()), - const_cast(ffn2_code_zp.get_ptr()), + up_gate_proj_weight, + down_proj_weight, + const_cast(up_gate_proj_bias.get_ptr()), + const_cast(up_gate_proj_scale.get_ptr()), + const_cast(down_proj_scale.get_ptr()), + 
const_cast(up_gate_proj_local_scale.get_ptr()), + const_cast(up_gate_proj_code_scale.get_ptr()), + const_cast(up_gate_proj_code_zp.get_ptr()), + const_cast(down_proj_local_scale.get_ptr()), + const_cast(down_proj_code_scale.get_ptr()), + const_cast(down_proj_code_zp.get_ptr()), fc1_out_tensor, ffn_out, total_rows_in_ll_else_minus1, @@ -184,17 +184,17 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input, paddle::Tensor MoeExpertFFNWint2Func( const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_scale, - const paddle::optional& ffn1_local_scale, - const paddle::optional& ffn1_code_scale, - const paddle::optional& ffn1_code_zp, - const paddle::optional& ffn2_local_scale, - const paddle::optional& ffn2_code_scale, - const paddle::optional& ffn2_code_zp, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_scale, + const paddle::optional& up_gate_proj_local_scale, + const paddle::optional& up_gate_proj_code_scale, + const paddle::optional& up_gate_proj_code_zp, + const paddle::optional& down_proj_local_scale, + const paddle::optional& down_proj_code_scale, + const paddle::optional& down_proj_code_zp, const bool used_in_ep_low_latency) { const auto dtype = permute_input.dtype(); @@ -204,34 +204,34 @@ paddle::Tensor MoeExpertFFNWint2Func( case paddle::DataType::BFLOAT16: MoeFFNWint2Kernel(permute_input, tokens_expert_prefix_sum, - ffn1_weight, - ffn2_weight, - ffn1_bias, - ffn1_scale, - ffn2_scale, - ffn1_local_scale, - ffn1_code_scale, - ffn1_code_zp, - ffn2_local_scale, - ffn2_code_scale, - ffn2_code_zp, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + up_gate_proj_local_scale, + up_gate_proj_code_scale, + up_gate_proj_code_zp, + down_proj_local_scale, + down_proj_code_scale, + down_proj_code_zp, ffn_out, used_in_ep_low_latency); break; case paddle::DataType::FLOAT16: MoeFFNWint2Kernel(permute_input, tokens_expert_prefix_sum, - ffn1_weight, - ffn2_weight, - ffn1_bias, - ffn1_scale, - ffn2_scale, - ffn1_local_scale, - ffn1_code_scale, - ffn1_code_zp, - ffn2_local_scale, - ffn2_code_scale, - ffn2_code_zp, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + up_gate_proj_local_scale, + up_gate_proj_code_scale, + up_gate_proj_code_zp, + down_proj_local_scale, + down_proj_code_scale, + down_proj_code_zp, ffn_out, used_in_ep_low_latency); break; @@ -244,49 +244,49 @@ paddle::Tensor MoeExpertFFNWint2Func( std::vector MoeExpertFFNWint2( const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_scale, - const paddle::optional& ffn1_local_scale, - const paddle::optional& ffn1_code_scale, - const paddle::optional& ffn1_code_zp, - const paddle::optional& ffn2_local_scale, - const paddle::optional& ffn2_code_scale, - const paddle::optional& ffn2_code_zp, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& 
up_gate_proj_scale, + const paddle::optional& down_proj_scale, + const paddle::optional& up_gate_proj_local_scale, + const paddle::optional& up_gate_proj_code_scale, + const paddle::optional& up_gate_proj_code_zp, + const paddle::optional& down_proj_local_scale, + const paddle::optional& down_proj_code_scale, + const paddle::optional& down_proj_code_zp, const bool used_in_ep_low_latency) { return {MoeExpertFFNWint2Func(permute_input, tokens_expert_prefix_sum, - ffn1_weight, - ffn2_weight, - ffn1_bias, - ffn1_scale, - ffn2_scale, - ffn1_local_scale, - ffn1_code_scale, - ffn1_code_zp, - ffn2_local_scale, - ffn2_code_scale, - ffn2_code_zp, + up_gate_proj_weight, + down_proj_weight, + up_gate_proj_bias, + up_gate_proj_scale, + down_proj_scale, + up_gate_proj_local_scale, + up_gate_proj_code_scale, + up_gate_proj_code_zp, + down_proj_local_scale, + down_proj_code_scale, + down_proj_code_zp, used_in_ep_low_latency)}; } std::vector> MoeExpertFFNWint2InferShape( const std::vector& permute_input_shape, const std::vector& tokens_expert_prefix_sum_shape, - const std::vector& ffn1_weight_shape, - const std::vector& ffn2_weight_shape, - const paddle::optional>& ffn1_bias_shape, - const paddle::optional>& ffn1_scale_shape, - const paddle::optional>& ffn2_scale_shape, - const paddle::optional>& ffn1_local_scale_shape, - const paddle::optional>& ffn1_code_scale_shape, - const paddle::optional>& ffn1_code_zp_shape, - const paddle::optional>& ffn2_local_scale_shape, - const paddle::optional>& ffn2_code_scale_shape, - const paddle::optional>& ffn2_code_zp_shape, + const std::vector& up_gate_proj_weight_shape, + const std::vector& down_proj_weight_shape, + const paddle::optional>& up_gate_proj_bias_shape, + const paddle::optional>& up_gate_proj_scale_shape, + const paddle::optional>& down_proj_scale_shape, + const paddle::optional>& up_gate_proj_local_scale_shape, + const paddle::optional>& up_gate_proj_code_scale_shape, + const paddle::optional>& up_gate_proj_code_zp_shape, + const paddle::optional>& down_proj_local_scale_shape, + const paddle::optional>& down_proj_code_scale_shape, + const paddle::optional>& down_proj_code_zp_shape, const bool used_in_ep_low_latency) { return {permute_input_shape}; @@ -295,17 +295,17 @@ std::vector> MoeExpertFFNWint2InferShape( std::vector MoeExpertFFNWint2InferDtype( const paddle::DataType &permute_input_dtype, const paddle::DataType &tokens_expert_prefix_sum_dtype, - const paddle::DataType &ffn1_weight_dtype, - const paddle::DataType &ffn2_weight_dtype, - const paddle::optional &ffn1_bias_dtype, - const paddle::optional &ffn1_scale_dtype, - const paddle::optional &ffn2_scale_dtype, - const paddle::optional &ffn1_local_scale_dtype, - const paddle::optional &ffn1_code_scale_dtype, - const paddle::optional &ffn1_code_zp_dtype, - const paddle::optional &ffn2_local_scale_dtype, - const paddle::optional &ffn2_code_scale_dtype, - const paddle::optional &ffn2_code_zp_dtype, + const paddle::DataType &up_gate_proj_weight_dtype, + const paddle::DataType &down_proj_weight_dtype, + const paddle::optional &up_gate_proj_bias_dtype, + const paddle::optional &up_gate_proj_scale_dtype, + const paddle::optional &down_proj_scale_dtype, + const paddle::optional &up_gate_proj_local_scale_dtype, + const paddle::optional &up_gate_proj_code_scale_dtype, + const paddle::optional &up_gate_proj_code_zp_dtype, + const paddle::optional &down_proj_local_scale_dtype, + const paddle::optional &down_proj_code_scale_dtype, + const paddle::optional &down_proj_code_zp_dtype, const bool 
used_in_ep_low_latency) { return {permute_input_dtype}; @@ -315,9 +315,9 @@ std::vector MoeExpertFFNWint2InferDtype( * @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network Operator * * This operator performs the expert computation in MoE architecture, including: - * 1. First linear transformation (FFN1) with optional quantization + * 1. First linear transformation (up_gate_proj) with optional quantization * 2. SwiGLU activation function - * 3. Second linear transformation (FFN2) with optional quantization + * 3. Second linear transformation (down_proj) with optional quantization * * Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization. * @@ -328,26 +328,26 @@ std::vector MoeExpertFFNWint2InferDtype( * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm * Shape: [num_experts] * dtype: int64 - * - ffn1_weight: First FFN layer weights + * - up_gate_proj_weight: First FFN layer weights * Shape: [num_experts, inter_size * 2, hidden_size] * dtype: Same as input (unquantized) or int8 (quantized) - * - ffn2_weight: Second FFN layer weights + * - down_proj_weight: Second FFN layer weights * Shape: [num_experts, hidden_size, inter_size] * dtype: Same as input (unquantized) or int8 (quantized) - * - ffn1_bias: Optional bias for first FFN layer + * - up_gate_proj_bias: Optional bias for first FFN layer * Shape: [num_experts, inter_size * 2] * dtype: Same as input - * - ffn1_scale: Quantization scales for first FFN layer + * - up_gate_proj_scale: Quantization scales for first FFN layer * Shape: [num_experts, inter_size * 2] * dtype: Same as input - * - ffn2_scale: Quantization scales for second FFN layer + * - down_proj_scale: Quantization scales for second FFN layer * Shape: [num_experts, hidden_size] * dtype: Same as input * * Outputs: * - output_tensor: Output tensor after MoE FFN computation * Shape: Same as permute_input - * dtype: Same as input (or ffn1_scale dtype for w4a8) + * dtype: Same as input (or up_gate_proj_scale dtype for w4a8) * * Attributes: * - used_in_ep_low_latency: Whether running in low latency mode @@ -359,17 +359,17 @@ std::vector MoeExpertFFNWint2InferDtype( PD_BUILD_STATIC_OP(moe_expert_ffn_wint2) .Inputs({"permute_input", "tokens_expert_prefix_sum", - "ffn1_weight", - "ffn2_weight", - paddle::Optional("ffn1_bias"), - paddle::Optional("ffn1_scale"), - paddle::Optional("ffn2_scale"), - paddle::Optional("ffn1_local_scale"), - paddle::Optional("ffn1_code_scale"), - paddle::Optional("ffn1_code_zp"), - paddle::Optional("ffn2_local_scale"), - paddle::Optional("ffn2_code_scale"), - paddle::Optional("ffn2_code_zp")}) + "up_gate_proj_weight", + "down_proj_weight", + paddle::Optional("up_gate_proj_bias"), + paddle::Optional("up_gate_proj_scale"), + paddle::Optional("down_proj_scale"), + paddle::Optional("up_gate_proj_local_scale"), + paddle::Optional("up_gate_proj_code_scale"), + paddle::Optional("up_gate_proj_code_zp"), + paddle::Optional("down_proj_local_scale"), + paddle::Optional("down_proj_code_scale"), + paddle::Optional("down_proj_code_zp")}) .Outputs({"output_tensor"}) .Attrs({"used_in_ep_low_latency:bool"}) .SetKernelFn(PD_KERNEL(MoeExpertFFNWint2)) diff --git a/custom_ops/gpu_ops/moe/moe_reduce.cu b/custom_ops/gpu_ops/moe/moe_reduce.cu index ecbd25af7..e10bf9121 100644 --- a/custom_ops/gpu_ops/moe/moe_reduce.cu +++ b/custom_ops/gpu_ops/moe/moe_reduce.cu @@ -25,7 +25,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight, const paddle::Tensor 
&permute_indices_per_token, const paddle::Tensor &top_k_indices, - const paddle::optional &ffn2_bias, + const paddle::optional &down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor, const int num_rows, const int hidden_size, const int topk, @@ -38,7 +38,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out, finalize_moe_routing_kernelLauncher::run( ffn_out.data(), output->data(), - ffn2_bias ? ffn2_bias->data() : nullptr, + down_proj_bias ? down_proj_bias->data() : nullptr, top_k_weight.data(), permute_indices_per_token.data(), top_k_indices.data(), num_rows, hidden_size, topk, static_cast(1), norm_topk_prob, routed_scaling_factor, stream); @@ -48,7 +48,7 @@ paddle::Tensor MoeExpertReduceFunc( const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight, const paddle::Tensor &permute_indices_per_token, const paddle::Tensor &top_k_indices, - const paddle::optional &ffn2_bias, + const paddle::optional &down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor) { const auto input_type = ffn_out.dtype(); auto place = ffn_out.place(); @@ -63,13 +63,13 @@ paddle::Tensor MoeExpertReduceFunc( case paddle::DataType::BFLOAT16: MoeReduceKernel( ffn_out, top_k_weight, permute_indices_per_token, top_k_indices, - ffn2_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size, + down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size, topk, &output); break; case paddle::DataType::FLOAT16: MoeReduceKernel( ffn_out, top_k_weight, permute_indices_per_token, top_k_indices, - ffn2_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size, + down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size, topk, &output); break; default: @@ -83,10 +83,10 @@ MoeExpertReduce(const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight, const paddle::Tensor &permute_indices_per_token, const paddle::Tensor &top_k_indices, - const paddle::optional &ffn2_bias, + const paddle::optional &down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor) { return {MoeExpertReduceFunc(ffn_out, top_k_weight, permute_indices_per_token, - top_k_indices, ffn2_bias, norm_topk_prob, + top_k_indices, down_proj_bias, norm_topk_prob, routed_scaling_factor)}; } @@ -95,7 +95,7 @@ std::vector> MoeExpertReduceInferShape( const std::vector &top_k_weight_shape, const std::vector &permute_indices_per_token_shape, const std::vector &top_k_indices_shape, - const paddle::optional> &ffn2_bias_shape) { + const paddle::optional> &down_proj_bias_shape) { const int moe_topk = top_k_indices_shape[1]; auto out_shape = ffn_out_shape; if (out_shape[0] != -1) out_shape[0] /= moe_topk; @@ -107,19 +107,19 @@ std::vector MoeExpertReduceInferDtype( const paddle::DataType &top_k_weight_dtype, const paddle::DataType &permute_indices_per_token_dtype, const paddle::DataType &top_k_indices_dtype, - const paddle::optional &ffn2_bias_dtype) { + const paddle::optional &down_proj_bias_dtype) { return {ffn_out_dtype}; } /** * @brief Mixture of Experts (MoE) Expert Reduce Operator - * + * * This operator performs the following key functions: * 1. Combines outputs from multiple experts based on routing weights * 2. Applies optional bias and scaling to the combined output * 3. 
Restores the original token order from permuted expert outputs - * + * * Inputs: * - ffn_out: Outputs from all expert networks (permuted) * Shape: [total_tokens * moe_topk, hidden_size] @@ -133,19 +133,19 @@ std::vector MoeExpertReduceInferDtype( * - top_k_indices: Indices of selected top-k experts for each token * Shape: [total_tokens, moe_topk] * dtype: int32 - * - ffn2_bias: Optional bias term for expert outputs (hidden_size) - * + * - down_proj_bias: Optional bias term for expert outputs (hidden_size) + * * Outputs: * - output: Combined expert outputs in original token order * Shape: [total_tokens, hidden_size] * dtype: Same as ffn_out - * + * * Attributes: * - norm_topk_prob: Whether to normalize top-k probabilities * (true: weights sum to 1 for each token, * false: use raw weights) * - routed_scaling_factor: Scaling factor applied to top-k probabilities - * + * * Note: * - The operator expects permuted expert outputs from moe_expert_dispatch * - When norm_topk_prob is true, weights are normalized per token @@ -154,7 +154,7 @@ std::vector MoeExpertReduceInferDtype( */ PD_BUILD_STATIC_OP(moe_expert_reduce) .Inputs({"ffn_out", "top_k_weight", "permute_indices_per_token", - "top_k_indices", paddle::Optional("ffn2_bias")}) + "top_k_indices", paddle::Optional("down_proj_bias")}) .Outputs({"output"}) .Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"}) .SetKernelFn(PD_KERNEL(MoeExpertReduce)) diff --git a/custom_ops/iluvatar_ops/moe_reduce.cu b/custom_ops/iluvatar_ops/moe_reduce.cu index dda0ce44b..8e58db47d 100644 --- a/custom_ops/iluvatar_ops/moe_reduce.cu +++ b/custom_ops/iluvatar_ops/moe_reduce.cu @@ -25,7 +25,7 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out, const paddle::Tensor& top_k_weight, const paddle::Tensor& permute_indices_per_token, const paddle::Tensor& top_k_indices, - const paddle::optional& ffn2_bias, + const paddle::optional& down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor, const int num_rows, @@ -42,7 +42,7 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out, finalize_moe_routing_kernelLauncher( ffn_out.data(), output->data(), - ffn2_bias ? ffn2_bias->data() : nullptr, + down_proj_bias ? 
down_proj_bias->data() : nullptr, top_k_weight.data(), permute_indices_per_token.data(), top_k_indices.data(), @@ -60,7 +60,7 @@ paddle::Tensor MoeExpertReduceFunc( const paddle::Tensor& top_k_weight, const paddle::Tensor& permute_indices_per_token, const paddle::Tensor& top_k_indices, - const paddle::optional& ffn2_bias, + const paddle::optional& down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor) { const auto input_type = ffn_out.dtype(); @@ -79,7 +79,7 @@ paddle::Tensor MoeExpertReduceFunc( top_k_weight, permute_indices_per_token, top_k_indices, - ffn2_bias, + down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, @@ -93,7 +93,7 @@ paddle::Tensor MoeExpertReduceFunc( top_k_weight, permute_indices_per_token, top_k_indices, - ffn2_bias, + down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, @@ -112,14 +112,14 @@ std::vector MoeExpertReduce( const paddle::Tensor& top_k_weight, const paddle::Tensor& permute_indices_per_token, const paddle::Tensor& top_k_indices, - const paddle::optional& ffn2_bias, + const paddle::optional& down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor) { return {MoeExpertReduceFunc(ffn_out, top_k_weight, permute_indices_per_token, top_k_indices, - ffn2_bias, + down_proj_bias, norm_topk_prob, routed_scaling_factor)}; } @@ -129,7 +129,7 @@ std::vector> MoeExpertReduceInferShape( const std::vector& top_k_weight_shape, const std::vector& permute_indices_per_token_shape, const std::vector& top_k_indices_shape, - const paddle::optional>& ffn2_bias_shape) { + const paddle::optional>& down_proj_bias_shape) { return {ffn_out_shape}; } @@ -138,7 +138,7 @@ std::vector MoeExpertReduceInferDtype( const paddle::DataType& top_k_weight_dtype, const paddle::DataType& permute_indices_per_token_dtype, const paddle::DataType& top_k_indices_dtype, - const paddle::optional& ffn2_bias_dtype) { + const paddle::optional& down_proj_bias_dtype) { return {ffn_out_dtype}; } @@ -147,7 +147,7 @@ PD_BUILD_STATIC_OP(moe_expert_reduce) "top_k_weight", "permute_indices_per_token", "top_k_indices", - paddle::Optional("ffn2_bias")}) + paddle::Optional("down_proj_bias")}) .Outputs({"output"}) .Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"}) .SetKernelFn(PD_KERNEL(MoeExpertReduce)) diff --git a/custom_ops/xpu_ops/src/ops/moe_layer.cc b/custom_ops/xpu_ops/src/ops/moe_layer.cc index d7470bb87..70f4fac52 100644 --- a/custom_ops/xpu_ops/src/ops/moe_layer.cc +++ b/custom_ops/xpu_ops/src/ops/moe_layer.cc @@ -46,12 +46,12 @@ template std::vector MoeLayerKernel( const paddle::Tensor &x, const paddle::Tensor &gate_weight, const paddle::optional &gate_correction_bias, - const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight, - const paddle::optional &ffn1_bias, - const paddle::optional &ffn2_bias, - const paddle::optional &ffn1_weight_scale, - const paddle::optional &ffn2_weight_scale, - const paddle::optional &ffn2_in_scale, // not support + const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight, + const paddle::optional &up_gate_proj_bias, + const paddle::optional &down_proj_bias, + const paddle::optional &up_gate_proj_weight_scale, + const paddle::optional &down_proj_weight_scale, + const paddle::optional &down_proj_in_scale, // not support const std::string &quant_method, const int moe_top_k, const bool moe_group) { // std::cout << "[Op Debug] enter moe layer" << std::endl; @@ -66,24 +66,24 @@ std::vector MoeLayerKernel( const auto xtype = x.dtype(); auto x_dims = x.shape(); - auto 
ffn1_dims = ffn1_weight.shape(); + auto up_gate_proj_dims = up_gate_proj_weight.shape(); PD_CHECK(x_dims.size() == 2, "x_dims.size() shoud be 2."); - PD_CHECK(ffn1_dims.size() == 3, "ffn1_dims.size() should be 3."); - PD_CHECK(ffn2_in_scale.get_ptr() == nullptr, "ffn2_in_scale not support."); + PD_CHECK(up_gate_proj_dims.size() == 3, "up_gate_proj_dims.size() should be 3."); + PD_CHECK(down_proj_in_scale.get_ptr() == nullptr, "down_proj_in_scale not support."); if (quant_method == "weight_only_int4") { - PD_CHECK(x_dims[1] == ffn1_dims[2] * 2, - "x_dims[1] should equal to ffn1_dims[2], (weight must be " + PD_CHECK(x_dims[1] == up_gate_proj_dims[2] * 2, + "x_dims[1] should equal to up_gate_proj_dims[2], (weight must be " "[e,n,k])."); } else { - PD_CHECK(x_dims[1] == ffn1_dims[2], - "x_dims[1] should equal to ffn1_dims[2], (weight must be " + PD_CHECK(x_dims[1] == up_gate_proj_dims[2], + "x_dims[1] should equal to up_gate_proj_dims[2], (weight must be " "[e,n,k])."); } int token_num = x_dims[0]; int hidden_dim = x_dims[1]; - int expert_num = ffn1_dims[0]; - int inter_dim = ffn1_dims[1]; + int expert_num = up_gate_proj_dims[0]; + int inter_dim = up_gate_proj_dims[1]; int outer_dim = inter_dim / 2; paddle::Tensor fused_moe_out = paddle::empty_like(x); @@ -104,7 +104,7 @@ std::vector MoeLayerKernel( // input + output xftblock::Tensor xin(const_cast(x.data() + x_offset), xftblock_tx, x_mpart_shape); - + xftblock::Tensor xout(fused_moe_out.mutable_data() + x_offset, xftblock_tx, x_mpart_shape); // gate @@ -118,63 +118,63 @@ std::vector MoeLayerKernel( gate_correction_bias.get_ptr()->shape()); } - // ffn1 + ffn2 - std::shared_ptr xffn1_w, xffn2_w; + // up_gate_proj + down_proj + std::shared_ptr xup_gate_proj_w, xdown_proj_w; if (std::is_same::value) { - xffn1_w = std::make_shared( - const_cast(ffn1_weight.data()), nullptr, - const_cast(ffn1_weight_scale.get_ptr() - ? ffn1_weight_scale.get_ptr()->data() + xup_gate_proj_w = std::make_shared( + const_cast(up_gate_proj_weight.data()), nullptr, + const_cast(up_gate_proj_weight_scale.get_ptr() + ? up_gate_proj_weight_scale.get_ptr()->data() : nullptr), xftblock_tw, std::vector{expert_num, inter_dim, hidden_dim}); - xffn2_w = std::make_shared( - const_cast(ffn2_weight.data()), nullptr, - const_cast(ffn2_weight_scale.get_ptr() - ? ffn2_weight_scale.get_ptr()->data() + xdown_proj_w = std::make_shared( + const_cast(down_proj_weight.data()), nullptr, + const_cast(down_proj_weight_scale.get_ptr() + ? down_proj_weight_scale.get_ptr()->data() : nullptr), xftblock_tw, std::vector{expert_num, hidden_dim, outer_dim}); } else { - xffn1_w = std::make_shared( - const_cast(ffn1_weight.data()), nullptr, - const_cast(ffn1_weight_scale.get_ptr() - ? ffn1_weight_scale.get_ptr()->data() + xup_gate_proj_w = std::make_shared( + const_cast(up_gate_proj_weight.data()), nullptr, + const_cast(up_gate_proj_weight_scale.get_ptr() + ? up_gate_proj_weight_scale.get_ptr()->data() : nullptr), xftblock_tw, std::vector{expert_num, inter_dim, hidden_dim}); - xffn2_w = std::make_shared( - const_cast(ffn2_weight.data()), nullptr, - const_cast(ffn2_weight_scale.get_ptr() - ? ffn2_weight_scale.get_ptr()->data() + xdown_proj_w = std::make_shared( + const_cast(down_proj_weight.data()), nullptr, + const_cast(down_proj_weight_scale.get_ptr() + ? 
down_proj_weight_scale.get_ptr()->data() : nullptr), xftblock_tw, std::vector{expert_num, hidden_dim, outer_dim}); } - std::shared_ptr xffn1_bias; - std::shared_ptr xffn2_bias; - if (ffn1_bias.get_ptr()) { - xffn1_bias = std::make_shared( - const_cast(ffn1_bias.get_ptr()->data()), - xftblock::DataType::DT_FLOAT, ffn1_bias.get_ptr()->shape()); + std::shared_ptr xup_gate_proj_bias; + std::shared_ptr xdown_proj_bias; + if (up_gate_proj_bias.get_ptr()) { + xup_gate_proj_bias = std::make_shared( + const_cast(up_gate_proj_bias.get_ptr()->data()), + xftblock::DataType::DT_FLOAT, up_gate_proj_bias.get_ptr()->shape()); } - if (ffn2_bias.get_ptr()) { - xffn2_bias = std::make_shared( - const_cast(ffn2_bias.get_ptr()->data()), - xftblock::DataType::DT_FLOAT, ffn2_bias.get_ptr()->shape()); + if (down_proj_bias.get_ptr()) { + xdown_proj_bias = std::make_shared( + const_cast(down_proj_bias.get_ptr()->data()), + xftblock::DataType::DT_FLOAT, down_proj_bias.get_ptr()->shape()); } // std::cout << "[Op Debug] start init moe_ffn weight and bias" << // std::endl; MoeFFNWeight xftblock::MoeFFNWeight moe_ffn_w_struct; moe_ffn_w_struct.gate_weight = &xgate_w; - moe_ffn_w_struct.ffn_inter_weights = xffn1_w.get(); - moe_ffn_w_struct.ffn_inter_bias = xffn1_bias.get(); - moe_ffn_w_struct.ffn_outer_weights = xffn2_w.get(); - moe_ffn_w_struct.ffn_outer_bias = xffn2_bias.get(); + moe_ffn_w_struct.ffn_inter_weights = xup_gate_proj_w.get(); + moe_ffn_w_struct.ffn_inter_bias = xup_gate_proj_bias.get(); + moe_ffn_w_struct.ffn_outer_weights = xdown_proj_w.get(); + moe_ffn_w_struct.ffn_outer_bias = xdown_proj_bias.get(); moe_ffn_w_struct.score_bias = xgate_correct_bias.get(); // MoeFFNParam xftblock::MoeFFNParam moe_ffn_param; @@ -191,29 +191,29 @@ std::vector MoeLayerKernel( PD_CHECK(ret == 0, "xftblock::moe_ffn_block_sorted_castte_per_token failed"); } - + return {fused_moe_out}; } std::vector MoeLayer(const paddle::Tensor &x, const paddle::Tensor &gate_weight, const paddle::optional &gate_correction_bias, - const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight, - const paddle::optional &ffn1_bias, - const paddle::optional &ffn2_bias, - const paddle::optional &ffn1_weight_scale, - const paddle::optional &ffn2_weight_scale, - const paddle::optional &ffn2_in_scale, + const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight, + const paddle::optional &up_gate_proj_bias, + const paddle::optional &down_proj_bias, + const paddle::optional &up_gate_proj_weight_scale, + const paddle::optional &down_proj_weight_scale, + const paddle::optional &down_proj_in_scale, const std::string &quant_method, const int moe_top_k, const bool moe_group) { const auto x_type = x.dtype(); - const auto w_type = ffn1_weight.dtype(); + const auto w_type = up_gate_proj_weight.dtype(); #define APPLY_MOE_LAYER_KERNEL(TX, TW) \ return MoeLayerKernel( \ - x, gate_weight, gate_correction_bias, ffn1_weight, ffn2_weight, \ - ffn1_bias, ffn2_bias, ffn1_weight_scale, ffn2_weight_scale, \ - ffn2_in_scale, quant_method, moe_top_k, moe_group); + x, gate_weight, gate_correction_bias, up_gate_proj_weight, down_proj_weight, \ + up_gate_proj_bias, down_proj_bias, up_gate_proj_weight_scale, down_proj_weight_scale, \ + down_proj_in_scale, quant_method, moe_top_k, moe_group); // TODO(mayang02): how to use quant_method? 
if (x_type == paddle::DataType::BFLOAT16 && @@ -237,36 +237,36 @@ std::vector> MoeLayerInferShape( const std::vector &x_shape, const std::vector &gate_weight_shape, const paddle::optional> &gate_correction_bias_shape, - const std::vector &ffn1_weight_shape, - const std::vector &ffn2_weight_shape, - const paddle::optional> &ffn1_bias_shape, - const paddle::optional> &ffn2_bias_shape, - const paddle::optional> &ffn1_weight_scale_shape, - const paddle::optional> &ffn2_weight_scale_shape, - const paddle::optional> &ffn2_in_scale_shape) { + const std::vector &up_gate_proj_weight_shape, + const std::vector &down_proj_weight_shape, + const paddle::optional> &up_gate_proj_bias_shape, + const paddle::optional> &down_proj_bias_shape, + const paddle::optional> &up_gate_proj_weight_scale_shape, + const paddle::optional> &down_proj_weight_scale_shape, + const paddle::optional> &down_proj_in_scale_shape) { return {x_shape}; } std::vector MoeLayerInferDtype( const paddle::DataType &x_dtype, const paddle::DataType &gate_weight_dtype, const paddle::optional &gate_correction_bias_dtype, - const paddle::DataType &ffn1_weight_dtype, - const paddle::DataType &ffn2_weight_dtype, - const paddle::optional &ffn1_bias_dtype, - const paddle::optional &ffn2_bias_dtype, - const paddle::optional &ffn1_weight_scale_dtype, - const paddle::optional &ffn2_weight_scale_dtype, - const paddle::optional &ffn2_in_scale_dtype) { + const paddle::DataType &up_gate_proj_weight_dtype, + const paddle::DataType &down_proj_weight_dtype, + const paddle::optional &up_gate_proj_bias_dtype, + const paddle::optional &down_proj_bias_dtype, + const paddle::optional &up_gate_proj_weight_scale_dtype, + const paddle::optional &down_proj_weight_scale_dtype, + const paddle::optional &down_proj_in_scale_dtype) { return {x_dtype}; } PD_BUILD_OP(xpu_moe_layer) // fused_moe .Inputs({"x", "gate_weight", paddle::Optional("gate_correction_bias"), - "ffn1_weight", "ffn2_weight", paddle::Optional("ffn1_bias"), - paddle::Optional("ffn2_bias"), - paddle::Optional("ffn1_weight_scale"), - paddle::Optional("ffn2_weight_scale"), - paddle::Optional("ffn2_in_scale")}) + "up_gate_proj_weight", "down_proj_weight", paddle::Optional("up_gate_proj_bias"), + paddle::Optional("down_proj_bias"), + paddle::Optional("up_gate_proj_weight_scale"), + paddle::Optional("down_proj_weight_scale"), + paddle::Optional("down_proj_in_scale")}) .Outputs({"fused_moe_out"}) .Attrs({"quant_method:std::string", "moe_top_k:int", "moe_group:bool"}) .SetKernelFn(PD_KERNEL(MoeLayer)) diff --git a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py index 1b4a15621..2961d3df6 100644 --- a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py +++ b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py @@ -19,12 +19,10 @@ from paddle import nn from fastdeploy.distributed.communication_op import \ tensor_model_parallel_all_reduce -from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map, - get_tensor) +from fastdeploy.model_executor.layers.quantization.quant_base import \ + QuantMethodBase from fastdeploy.utils import ceil_div -from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase - class DCUTritonWeightOnlyMoEMethod(QuantMethodBase): """ @@ -36,9 +34,9 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase): Triton Group Gemm to compute Fused MoE. 
""" self.quant_method = quant_method - self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"] + self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] self.added_scale_attrs = [ - "moe_ffn1_weight_scale", "moe_ffn2_weight_scale" + "up_gate_proj_weight_scale", "down_proj_weight_scale" ] def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None: @@ -49,26 +47,26 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase): """ Triton MoE create weight process. """ - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) - assert len(ffn1_weights) == layer.num_local_experts - assert len(ffn2_weights) == layer.num_local_experts + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + assert len(up_gate_proj_weights) == layer.num_local_experts + assert len(down_proj_weights) == layer.num_local_experts assert self.quant_method.name() == "wint8" - assert ffn1_weights[0].shape == [ + assert up_gate_proj_weights[0].shape == [ layer.hidden_size, layer.moe_intermediate_size * 2 ] - assert ffn2_weights[0].shape == [ + assert down_proj_weights[0].shape == [ layer.moe_intermediate_size, layer.hidden_size ] - ffn1_tensor = paddle.stack(ffn1_weights, axis=0) - ffn2_tensor = paddle.stack(ffn2_weights, axis=0) + up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0) + down_proj_tensor = paddle.stack(down_proj_weights, axis=0) if self.quant_method.name() == "wint8": max_bound = 127 elif self.quant_method.name() == "wint4": max_bound = 7 - for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]): + for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]): weight_name = self.added_weight_attrs[idx] scale_name = self.added_scale_attrs[idx] @@ -150,10 +148,10 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase): fused_moe_kernel_paddle[grid]( x, - layer.moe_ffn1_weight, + layer.up_gate_proj_weight, intermediate_cache1, None, - layer.moe_ffn1_weight_scale, + layer.up_gate_proj_weight_scale, None, sorted_token_ids, expert_ids, @@ -164,17 +162,17 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase): token_num * top_k, stride_am=x.strides[0], stride_ak=x.strides[1], - stride_be=layer.moe_ffn1_weight.strides[0], - stride_bk=layer.moe_ffn1_weight.strides[1], - stride_bn=layer.moe_ffn1_weight.strides[2], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], stride_cm=intermediate_cache1.strides[0], stride_cn=intermediate_cache1.strides[1], # stride_asm=-1, stride_ask=-1, - stride_bse=layer.moe_ffn1_weight_scale.strides[0], + stride_bse=layer.up_gate_proj_weight_scale.strides[0], stride_bsk=-1, - stride_bsn=layer.moe_ffn1_weight_scale.strides[1], + stride_bsn=layer.up_gate_proj_weight_scale.strides[1], group_n=-1, group_k=-1, # Meta-parameters @@ -197,10 +195,10 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase): ceil_div(hidden_size, config["BLOCK_SIZE_N"]), ) fused_moe_kernel_paddle[grid]( intermediate_cache2, - layer.moe_ffn2_weight, + layer.down_proj_weight, intermediate_cache3, None, - layer.moe_ffn2_weight_scale, + layer.down_proj_weight_scale, topk_weights, sorted_token_ids, expert_ids, @@ -211,16 +209,16 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase): token_num * top_k, stride_am=intermediate_cache2.strides[0], stride_ak=intermediate_cache2.strides[1], - stride_be=layer.moe_ffn2_weight.strides[0], - stride_bk=layer.moe_ffn2_weight.strides[1], - 
stride_bn=layer.moe_ffn2_weight.strides[2], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[1], + stride_bn=layer.down_proj_weight.strides[2], stride_cm=intermediate_cache3.strides[0], stride_cn=intermediate_cache3.strides[1], stride_asm=-1, stride_ask=-1, - stride_bse=layer.moe_ffn2_weight_scale.strides[0], + stride_bse=layer.down_proj_weight_scale.strides[0], stride_bsk=-1, - stride_bsn=layer.moe_ffn2_weight_scale.strides[1], + stride_bsn=layer.down_proj_weight_scale.strides[1], group_n=-1, group_k=-1, # Meta-parameters diff --git a/fastdeploy/model_executor/layers/backends/dcu/weight_only.py b/fastdeploy/model_executor/layers/backends/dcu/weight_only.py index f512a8850..a29403f5c 100644 --- a/fastdeploy/model_executor/layers/backends/dcu/weight_only.py +++ b/fastdeploy/model_executor/layers/backends/dcu/weight_only.py @@ -16,8 +16,8 @@ import paddle from paddle.nn.quant import weight_dequantize -from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig, GPUWeightOnlyLinearMethod - +from fastdeploy.model_executor.layers.quantization.weight_only import ( + GPUWeightOnlyLinearMethod, WeightOnlyConfig) class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod): @@ -35,12 +35,12 @@ class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod): def apply(self, layer, x): dequant_out = weight_dequantize( - x=layer.linear_weight, - scale=layer.linear_weight_scale, + x=layer.weight, + scale=layer.weight_scale, algo=self.quant_config.algo, out_dtype=paddle.get_default_dtype() ) linear_out = paddle.matmul(x, dequant_out) - if layer.linear_bias is not None: - linear_out = paddle.add(linear_out, layer.linear_bias) + if layer.bias is not None: + linear_out = paddle.add(linear_out, layer.bias) return linear_out diff --git a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py index 0e37430e7..42b931956 100644 --- a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py +++ b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py @@ -50,11 +50,11 @@ class GCUFusedMoeMethod(MoEMethodBase): Paddle gcu create weight process. 
""" # bf16 - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) - stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0) - stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0) + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0) + stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0) for idx, weight_tensor in enumerate( - [stacked_ffn1_weights, stacked_ffn2_weights]): + [stacked_up_gate_proj_weights, stacked_down_proj_weights]): # shape [E, K, N] -> [E, N, K] weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1]) weight_name = self.added_weight_attrs[idx] @@ -117,16 +117,16 @@ class GCUFusedMoeMethod(MoEMethodBase): dtype=x.dtype, ) - ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None - ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None + up_gate_proj_B_scale = layer.up_gate_proj_weight_scale if enable_quant else None + up_gate_proj_B_zeros = layer.up_gate_proj_weight_zeros if enable_quant else None invoke_fused_moe_kernel( x, # input - layer.moe_ffn1_weight, # weight + layer.up_gate_proj_weight, # weight intermediate_cache1, # output None, # A_scale - ffn1_B_scale, # B_scale - ffn1_B_zeros, # B_zp + up_gate_proj_B_scale, # B_scale + up_gate_proj_B_zeros, # B_zp topk_weights, topk_indices, sorted_token_ids, @@ -154,16 +154,16 @@ class GCUFusedMoeMethod(MoEMethodBase): dtype=x.dtype, ) - ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None - ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None + down_proj_B_scale = layer.down_proj_weight_scale if enable_quant else None + down_proj_B_zeros = layer.down_proj_weight_zeros if enable_quant else None invoke_fused_moe_kernel( intermediate_cache2, # input - layer.moe_ffn2_weight, # weight + layer.down_proj_weight, # weight intermediate_cache3, # output None, # A_scale - ffn2_B_scale, # B_scale - ffn2_B_zeros, # B_zp + down_proj_B_scale, # B_scale + down_proj_B_zeros, # B_zp topk_weights, topk_indices, sorted_token_ids, @@ -251,7 +251,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod): "GCUWeightOnlyMoEMethod only support weight_only_int4, but got:{self.quant_config.algo}" self.added_qzeros_attrs = [ - "moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros" + "up_gate_proj_weight_zeros", "down_proj_weight_zeros" ] self.group_size = 64 @@ -265,41 +265,41 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod): """ Paddle gcu process prequanted weights. 
""" - ffn1_expert_weight_key = layer.weight_key_map.get( - "ffn1_expert_weight_key", None) - ffn2_expert_weight_key = layer.weight_key_map.get( - "ffn2_expert_weight_key", None) - ffn1_expert_weight_scale_key = layer.weight_key_map.get( - "ffn1_expert_weight_scale_key", None) - ffn2_expert_weight_scale_key = layer.weight_key_map.get( - "ffn2_expert_weight_scale_key", None) + up_gate_proj_expert_weight_key = layer.weight_key_map.get( + "up_gate_proj_expert_weight_key", None) + down_proj_expert_weight_key = layer.weight_key_map.get( + "down_proj_expert_weight_key", None) + up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get( + "up_gate_proj_expert_weight_scale_key", None) + down_proj_expert_weight_scale_key = layer.weight_key_map.get( + "down_proj_expert_weight_scale_key", None) - ffn1_weights, ffn2_weights = layer.load_experts_weight( - state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key) - # self.check(layer, ffn1_weights, ffn2_weights) - ffn1_weight_scale = [] - ffn2_weight_scale = [] + up_gate_proj_weights, down_proj_weights = layer.load_experts_weight( + state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key) + # self.check(layer, up_gate_proj_weights, down_proj_weights) + up_gate_proj_weight_scale = [] + down_proj_weight_scale = [] for i in range(layer.num_experts): expert_idx = layer.expert_id_offset + i - ffn1_weight_scale.append( + up_gate_proj_weight_scale.append( get_tensor( state_dict.pop( - ffn1_expert_weight_scale_key.format(expert_idx)))) - ffn2_weight_scale.append( + up_gate_proj_expert_weight_scale_key.format(expert_idx)))) + down_proj_weight_scale.append( get_tensor( state_dict.pop( - ffn2_expert_weight_scale_key.format(expert_idx)))) + down_proj_expert_weight_scale_key.format(expert_idx)))) - ffn1_weight = paddle.stack(ffn1_weights, axis=0) - ffn2_weight = paddle.stack(ffn2_weights, axis=0) - ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0) - ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0) + up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0) + down_proj_weight = paddle.stack(down_proj_weights, axis=0) + up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0) + down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0) name_tensor_map = { - "moe_ffn1_weight": ffn1_weight, - "moe_ffn2_weight": ffn2_weight, - "moe_ffn1_weight_scale": ffn1_weight_scale, - "moe_ffn2_weight_scale": ffn2_weight_scale + "up_gate_proj_weight": up_gate_proj_weight, + "down_proj_weight": down_proj_weight, + "up_gate_proj_weight_scale": up_gate_proj_weight_scale, + "down_proj_weight_scale": down_proj_weight_scale } for name, tensor in name_tensor_map.items(): create_and_set_parameter(layer, name, tensor) @@ -310,8 +310,8 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod): """ Paddle cutlass create weight process. 
""" - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) - self.check(layer, ffn1_weights, ffn2_weights) + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + self.check(layer, up_gate_proj_weights, down_proj_weights) def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size): @@ -329,7 +329,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod): ) - for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]): + for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): weight_name = self.added_weight_attrs[idx] scale_name = self.added_scale_attrs[idx] zeros_name = self.added_qzeros_attrs[idx] @@ -365,8 +365,8 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod): dict_ = dict(shared_dict) for k, v in dict_.items(): - weight_list[k] = v[0].to(ffn1_weights[0].place) - weight_scale_list[k] = v[1].to(ffn1_weights[0].place) + weight_list[k] = v[0].to(up_gate_proj_weights[0].place) + weight_scale_list[k] = v[1].to(up_gate_proj_weights[0].place) else: remain_weights_start_idx = 0 diff --git a/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py b/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py index bddfa93f5..d390169fd 100644 --- a/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py @@ -38,14 +38,14 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod): def create_weights(self, layer): # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. - linear_weight_scale_shape = [layer.linear_weight_shape[1]] + weight_scale_shape = [layer.weight_shape[1]] - layer.linear_weight_shape.reverse() + layer.weight_shape.reverse() if self.quant_config.name() == "wint4": - layer.linear_weight_shape[0] //= 2 + layer.weight_shape[0] //= 2 layer.weight_dtype = "int8" - layer.linear_weight_scale = layer.create_parameter( - shape=linear_weight_scale_shape, + layer.weight_scale = layer.create_parameter( + shape=weight_scale_shape, dtype=layer._dtype, is_bias=False, ) @@ -61,8 +61,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod): """ quant_weight = get_tensor(state_dict.pop(layer.weight_key)) weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key)) - layer.linear_weight.set_value(quant_weight) - layer.linear_weight_scale.set_value( + layer.weight.set_value(quant_weight) + layer.weight_scale.set_value( weight_scale.astype(paddle.get_default_dtype())) @@ -73,8 +73,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod): self.group_size, # group_size ) - layer.linear_weight.set_value(quanted_weight_tensor) - layer.linear_weight_scale.set_value( + layer.weight.set_value(quanted_weight_tensor) + layer.weight_scale.set_value( weight_scale_tensor.astype(paddle.get_default_dtype())) @@ -82,8 +82,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod): def apply(self, layer, x): linear_out = linear_quant( lhs=x, - rhs=layer.linear_weight, - scale=layer.linear_weight_scale, + rhs=layer.weight, + scale=layer.weight_scale, bias=None, group_size=self.group_size, ) diff --git a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py index 388eefe09..36bd87bc0 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py @@ 
-37,13 +37,13 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod): Create weights for linear layer on XPU """ # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. - linear_weight_scale_shape = [layer.linear_weight_shape[1]] - layer.linear_weight_shape.reverse() + weight_scale_shape = [layer.weight_shape[1]] + layer.weight_shape.reverse() if self.quant_config.name() == "weight_only_int4": - layer.linear_weight_shape[0] //= 2 + layer.weight_shape[0] //= 2 layer.weight_dtype = "int8" - layer.linear_weight_scale = layer.create_parameter( - shape=linear_weight_scale_shape, + layer.weight_scale = layer.create_parameter( + shape=weight_scale_shape, dtype="float32", is_bias=False, ) @@ -55,6 +55,6 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod): """ quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu( weight, self.quant_config.algo, -1, -1) - layer.linear_weight.set_value( + layer.weight.set_value( paddle.transpose(quanted_weight_tensor, [1, 0])) - layer.linear_weight_scale.set_value(weight_scale_tensor) + layer.weight_scale.set_value(weight_scale_tensor) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index a0fb4fcc4..cc446f4bf 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -68,13 +68,13 @@ class VocabParallelEmbedding(nn.Layer): self.params_dtype: str = params_dtype if self.use_ep: - self.word_embeddings = nn.Embedding( + self.embeddings = nn.Embedding( num_embeddings, embedding_dim, ) else: if not self.column_cut: - self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding( + self.embeddings = fleet.meta_parallel.VocabParallelEmbedding( num_embeddings, embedding_dim, mp_group=fleet.get_hybrid_communicate_group(). @@ -85,13 +85,13 @@ class VocabParallelEmbedding(nn.Layer): ) else: # column cut embedding - self.word_embeddings = nn.Embedding( + self.embeddings = nn.Embedding( num_embeddings, embedding_dim // self.world_size, ) - self.word_embeddings.weight.is_distributed = True - self.word_embeddings.weight.split_axis = 1 + self.embeddings.weight.is_distributed = True + self.embeddings.weight.split_axis = 1 if not self.use_rope: self.position_embeddings = nn.Embedding( @@ -112,13 +112,12 @@ class VocabParallelEmbedding(nn.Layer): Args: state_dict (dict): A dictionary containing the checkpoint weights and biases. """ - a = state_dict[self.prefix + ".weight"] if self.tie_word_embeddings: - self.word_embeddings.weight.set_value( + self.embeddings.weight.set_value( get_tensor(state_dict[self.prefix + ".weight"]).astype( paddle.get_default_dtype())) else: - self.word_embeddings.weight.set_value( + self.embeddings.weight.set_value( get_tensor(state_dict.pop(self.prefix + ".weight")).astype( paddle.get_default_dtype())) @@ -134,10 +133,10 @@ class VocabParallelEmbedding(nn.Layer): Tensor: Embedded tensor representation of the input IDs. 
""" if self.use_ep: - input_embedings = self.word_embeddings(ids_remove_padding) + input_embedings = self.embeddings(ids_remove_padding) else: if self.column_cut: - input_embedings = self.word_embeddings(ids_remove_padding) + input_embedings = self.embeddings(ids_remove_padding) inputs_embeds_temp = [] paddle.distributed.all_gather( inputs_embeds_temp, @@ -148,6 +147,6 @@ class VocabParallelEmbedding(nn.Layer): ) input_embedings = paddle.concat(inputs_embeds_temp, -1) else: - input_embedings = self.word_embeddings(ids_remove_padding) + input_embedings = self.embeddings(ids_remove_padding) return input_embedings diff --git a/fastdeploy/model_executor/layers/hydra_head.py b/fastdeploy/model_executor/layers/hydra_head.py index 1e8ff64dd..2f3f026a5 100644 --- a/fastdeploy/model_executor/layers/hydra_head.py +++ b/fastdeploy/model_executor/layers/hydra_head.py @@ -14,16 +14,13 @@ # limitations under the License. """ -from paddleformers.utils.log import logger - import paddle import paddle.nn.functional as F from paddle import nn from paddle.distributed import fleet -from paddle.distributed.fleet.meta_parallel import ( - ColumnParallelLinear, - VocabParallelEmbedding, -) +from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear, + VocabParallelEmbedding) +from paddleformers.utils.log import logger from .utils import get_tensor @@ -130,7 +127,7 @@ class HydraHead(nn.Layer): ] ) - self.word_embeddings = VocabParallelEmbedding( + self.embeddings = VocabParallelEmbedding( vocab_size, hidden_size, mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), @@ -170,8 +167,8 @@ class HydraHead(nn.Layer): get_tensor(state_dict.pop(f"1.{hydra_head_idx}.weight")) ) - self.word_embeddings.weight.set_value( - get_tensor(state_dict.pop("word_embeddings.weight")) + self.embeddings.weight.set_value( + get_tensor(state_dict.pop("embeddings.weight")) ) def set_state_dict(self, state_dict): @@ -183,7 +180,7 @@ class HydraHead(nn.Layer): """ is_custom = True for key in state_dict.keys(): - if key != "word_embeddings.weight" and ( + if key != "embeddings.weight" and ( "hydra_mlp" in key or "hydra_head" in key ): is_custom = False @@ -207,7 +204,7 @@ class HydraHead(nn.Layer): hidden_states: [batch_size, hidden_size] The hidden_states of the last accept_tokens """ hydra_inputs = [hidden_states] - input_embeds = self.word_embeddings(input_ids) + input_embeds = self.embeddings(input_ids) for hydra_head_idx in range(self.hydra_num_heads): hydra_inputs.append(input_embeds) head_input = paddle.concat(hydra_inputs, axis=-1) @@ -217,4 +214,4 @@ class HydraHead(nn.Layer): _, topk_tokens = paddle.topk(probs, k=1, axis=-1) next_tokens[:, 1 + hydra_head_idx : 2 + hydra_head_idx] = topk_tokens[:] - input_embeds = self.word_embeddings(next_tokens[:, 1 + hydra_head_idx]) + input_embeds = self.embeddings(next_tokens[:, 1 + hydra_head_idx]) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 813489d57..324a5eed6 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -79,7 +79,7 @@ class LinearBase(nn.Layer): self._dtype = self._helper.get_default_dtype() self.weight_dtype = self._dtype - self.linear_weight_shape = [ + self.weight_shape = [ self.input_size, self.output_size, ] @@ -96,16 +96,16 @@ class LinearBase(nn.Layer): """ if self.skip_quant: self.weight_dtype = self._dtype - self.linear_weight = self.create_parameter( - shape=self.linear_weight_shape, + self.weight = self.create_parameter( + 
shape=self.weight_shape, dtype=self.weight_dtype, is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) - self.linear_bias = None + self.bias = None if self.with_bias: - self.linear_bias = self.create_parameter( + self.bias = self.create_parameter( shape=[self.output_size], dtype=self._dtype, is_bias=True, @@ -136,7 +136,7 @@ class LinearBase(nn.Layer): if self.fd_config.quant_config: self.quant_method.process_loaded_weights(self, weight_tensor) else: - self.linear_weight.set_value(weight_tensor) + self.weight.set_value(weight_tensor) def load_state_dict(self, state_dict: dict): """ @@ -157,7 +157,7 @@ class LinearBase(nn.Layer): if self.with_bias: bias_tensor = paddle.to_tensor( get_tensor(state_dict.pop(self.bias_key))) - self.linear_bias.set_value(bias_tensor) + self.bias.set_value(bias_tensor) def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: """ @@ -175,9 +175,9 @@ class LinearBase(nn.Layer): if self.fd_config.quant_config: linear_out = self.quant_method.apply(self, x) else: - linear_out = paddle.matmul(x, self.linear_weight) + linear_out = paddle.matmul(x, self.weight) if self.with_bias: - linear_out = paddle.add(linear_out, self.linear_bias) + linear_out = paddle.add(linear_out, self.bias) return linear_out @@ -219,7 +219,7 @@ class ReplicatedLinear(LinearBase): skip_quant=skip_quant) self.hidden_size = fd_config.model_config.hidden_size - self.linear_weight_shape = [ + self.weight_shape = [ self.input_size, self.output_size, ] @@ -272,7 +272,7 @@ class ColumnParallelLinear(LinearBase): output_size, self.nranks) # Split the output_size using TP inference. self.hidden_size = fd_config.model_config.hidden_size - self.linear_weight_shape = [ + self.weight_shape = [ self.input_size, self.output_size, ] @@ -286,26 +286,26 @@ class ColumnParallelLinear(LinearBase): """ if self.skip_quant: self.weight_dtype = self._dtype - self.linear_weight = self.create_parameter( - shape=self.linear_weight_shape, + self.weight = self.create_parameter( + shape=self.weight_shape, dtype=self.weight_dtype, is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) if self.nranks > 0: # col parallel - _set_var_distributed(self.linear_weight, split_axis=1) + _set_var_distributed(self.weight, split_axis=1) - self.linear_bias = None + self.bias = None if self.with_bias: - self.linear_bias = self.create_parameter( + self.bias = self.create_parameter( shape=[self.output_size], dtype=self._dtype, is_bias=True, ) if self.nranks > 0: # col parallel - _set_var_distributed(self.linear_bias, split_axis=1) + _set_var_distributed(self.bias, split_axis=1) # smooth quant self.linear_shift = None @@ -333,7 +333,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): skip_quant: bool = False, ): """ - Initialize the fused ffn1 Linear layer with given parameters. + Initialize the fused up_gate_proj Linear layer with given parameters. Args: fd_config (FDConfig): Inference-related parameters. 
@@ -443,7 +443,7 @@ class QKVParallelLinear(ColumnParallelLinear): q_tensor = get_tensor(state_dict.pop(q_weight_key)) k_tensor = get_tensor(state_dict.pop(k_weight_key)) v_tensor = get_tensor(state_dict.pop(v_weight_key)) - + if self.kv_num_heads < self.nranks: sharedkv_index = (self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads) // self.nranks sharedkv_start = sharedkv_index * self.head_dim @@ -462,7 +462,7 @@ class QKVParallelLinear(ColumnParallelLinear): if self.fd_config.quant_config: self.quant_method.process_loaded_weights(self, weight_tensor) else: - self.linear_weight.set_value(weight_tensor) + self.weight.set_value(weight_tensor) def load_state_dict(self, state_dict: dict): """ @@ -485,7 +485,7 @@ class QKVParallelLinear(ColumnParallelLinear): if self.bias_key in state_dict.keys(): bias_tensor = paddle.to_tensor( get_tensor(state_dict.pop(self.bias_key))) - self.linear_bias.set_value(bias_tensor) + self.bias.set_value(bias_tensor) else: q_bias_key = self.bias_key.replace("qkv_proj", "q_proj") k_bias_key = self.bias_key.replace("qkv_proj", "k_proj") @@ -494,7 +494,7 @@ class QKVParallelLinear(ColumnParallelLinear): k_bias = get_tensor(state_dict.pop(k_bias_key)) v_bias = get_tensor(state_dict.pop(v_bias_key)) qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1) - self.linear_bias.set_value(qkv_bias) + self.bias.set_value(qkv_bias) class RowParallelLinear(LinearBase): @@ -554,7 +554,7 @@ class RowParallelLinear(LinearBase): self.input_size = divide(input_size, self.nranks) self.output_size = output_size - self.linear_weight_shape = [ + self.weight_shape = [ self.input_size, self.output_size, ] @@ -574,16 +574,16 @@ class RowParallelLinear(LinearBase): if self.skip_quant: self.weight_dtype = self._dtype - self.linear_weight = self.create_parameter( - shape=self.linear_weight_shape, + self.weight = self.create_parameter( + shape=self.weight_shape, dtype=self.weight_dtype, is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) - self.linear_bias = None + self.bias = None if self.with_bias: - self.linear_bias = self.create_parameter( + self.bias = self.create_parameter( shape=[self.hidden_size], dtype=self._dtype, is_bias=True, @@ -591,7 +591,7 @@ class RowParallelLinear(LinearBase): if self.nranks > 0: # row parallel - _set_var_distributed(self.linear_weight, split_axis=0) + _set_var_distributed(self.weight, split_axis=0) # smooth quant self.linear_shift = None @@ -601,7 +601,7 @@ class RowParallelLinear(LinearBase): if self.fd_config.quant_config: out = self.quant_method.apply(self, x) else: - out = paddle.matmul(x, self.linear_weight) + out = paddle.matmul(x, self.weight) if self.reduce_results and self.nranks > 1: tensor_model_parallel_all_reduce(out) diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index 1fac83f89..188c25c19 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -52,11 +52,11 @@ class ParallelLMHead(nn.Layer): with_bias (bool): whether to have bias. Default: False. 
""" super(ParallelLMHead, self).__init__() - self.linear_weight_key: str = prefix + ".weight" + self.weight_key: str = prefix + ".weight" if with_bias: - self.linear_bias_key: Optional[str] = prefix + ".bias" + self.bias_key: Optional[str] = prefix + ".bias" else: - self.linear_bias_key: Optional[str] = None + self.bias_key: Optional[str] = None self.use_ep: bool = fd_config.parallel_config.use_ep self.column_cut = True @@ -74,26 +74,26 @@ class ParallelLMHead(nn.Layer): else: if self.column_cut: need_gather = True - self.out_linear = ColumnParallelLinear( + self.linear = ColumnParallelLinear( embedding_dim, num_embeddings, mp_group=fleet.get_hybrid_communicate_group(). get_model_parallel_group(), weight_attr=None, has_bias=True - if self.linear_bias_key is not None else False, + if self.bias_key is not None else False, gather_output=need_gather, fuse_matmul_bias=False, # False diff更小 ) else: - self.out_linear = RowParallelLinear( + self.linear = RowParallelLinear( embedding_dim, num_embeddings, mp_group=fleet.get_hybrid_communicate_group(). get_model_parallel_group(), weight_attr=None, has_bias=True - if self.linear_bias_key is not None else False, + if self.bias_key is not None else False, input_is_parallel=False, fuse_matmul_bias=False, # False diff更小 ) @@ -109,25 +109,25 @@ class ParallelLMHead(nn.Layer): if self.use_ep: self.weight.set_value( - get_tensor(state_dict.pop(self.linear_weight_key)).astype( + get_tensor(state_dict.pop(self.weight_key)).astype( paddle.get_default_dtype())) else: if self.tie_word_embeddings: - self.out_linear.weight.set_value( - get_tensor(state_dict.pop(self.linear_weight_key)).astype( + self.linear.weight.set_value( + get_tensor(state_dict.pop(self.weight_key)).astype( paddle.get_default_dtype()).transpose([1, 0])) else: weight_tensor = get_tensor( - state_dict.pop(self.linear_weight_key)).astype( + state_dict.pop(self.weight_key)).astype( paddle.get_default_dtype()) - if self.out_linear.weight.shape != weight_tensor.shape: + if self.linear.weight.shape != weight_tensor.shape: weight_tensor = weight_tensor.transpose([1, 0]) - self.out_linear.weight.set_value(weight_tensor) + self.linear.weight.set_value(weight_tensor) - if self.linear_bias_key is not None: - bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype( + if self.bias_key is not None: + bias = get_tensor(state_dict.pop(self.bias_key)).astype( paddle.get_default_dtype()) - self.out_linear.bias.set_value(bias) + self.linear.bias.set_value(bias) def forward(self, input: paddle.Tensor) -> paddle.Tensor: """ @@ -143,5 +143,5 @@ class ParallelLMHead(nn.Layer): if self.use_ep: logits = paddle.matmul(logits, self.weight) else: - logits = self.out_linear(logits) + logits = self.linear(logits) return logits diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index d06b14e1b..2ae7bb515 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -34,9 +34,9 @@ class MoEMethodBase(QuantMethodBase): self.moe_quant_type = "w16a16" else: self.quant_config = quant_config - self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"] + self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] self.added_scale_attrs = [ - "moe_ffn1_weight_scale", "moe_ffn2_weight_scale" + "up_gate_proj_weight_scale", "down_proj_weight_scale" ] self.pack_num = 1 @@ -63,14 +63,14 @@ class MoEMethodBase(QuantMethodBase): """ pass - def 
check(self, layer: nn.Layer, ffn1_weights, ffn2_weights): + def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights): """ check layer is valid for this method """ - assert ffn1_weights[0].shape == [ + assert up_gate_proj_weights[0].shape == [ layer.hidden_size // self.pack_num, layer.moe_intermediate_size * 2 ] - assert ffn2_weights[0].shape == [ + assert down_proj_weights[0].shape == [ layer.moe_intermediate_size // self.pack_num, layer.hidden_size ] diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 48603cf1d..99ddb68cc 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -31,7 +31,8 @@ if current_platform.is_cuda() and not current_platform.is_dcu(): from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch, moe_expert_reduce, noaux_tc) elif current_platform.is_iluvatar(): - from fastdeploy.model_executor.ops.iluvatar import moe_expert_dispatch, moe_expert_reduce + from fastdeploy.model_executor.ops.iluvatar import (moe_expert_dispatch, + moe_expert_reduce) # used for deepseek_v3 @@ -65,11 +66,11 @@ class CutlassMoEMethod(MoEMethodBase): Paddle cutlass create weight process. """ # bf16 - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) - stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0) - stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0) + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0) + stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0) for idx, weight_tensor in enumerate( - [stacked_ffn1_weights, stacked_ffn2_weights]): + [stacked_up_gate_proj_weights, stacked_down_proj_weights]): weight_name = self.added_weight_attrs[idx] setattr( layer, weight_name, @@ -95,15 +96,15 @@ class CutlassMoEMethod(MoEMethodBase): return fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn( permute_input, token_nums_per_expert, - layer.moe_ffn1_weight, - layer.moe_ffn2_weight, + layer.up_gate_proj_weight, + layer.down_proj_weight, None, - (layer.moe_ffn1_weight_scale if hasattr( - layer, "moe_ffn1_weight_scale") else None), - (layer.moe_ffn2_weight_scale if hasattr( - layer, "moe_ffn2_weight_scale") else None), - (layer.moe_ffn2_in_scale - if hasattr(layer, "moe_ffn2_in_scale") else None), + (layer.up_gate_proj_weight_scale if hasattr( + layer, "up_gate_proj_weight_scale") else None), + (layer.down_proj_weight_scale if hasattr( + layer, "down_proj_weight_scale") else None), + (layer.down_proj_in_scale + if hasattr(layer, "down_proj_in_scale") else None), expert_idx_per_token, self.moe_quant_type, used_in_ep_low_latency, @@ -111,15 +112,15 @@ class CutlassMoEMethod(MoEMethodBase): return fastdeploy.model_executor.ops.gpu.moe_expert_ffn( permute_input, token_nums_per_expert, - layer.moe_ffn1_weight, - layer.moe_ffn2_weight, + layer.up_gate_proj_weight, + layer.down_proj_weight, None, - (layer.moe_ffn1_weight_scale - if hasattr(layer, "moe_ffn1_weight_scale") else None), - (layer.moe_ffn2_weight_scale - if hasattr(layer, "moe_ffn2_weight_scale") else None), - (layer.moe_ffn2_in_scale - if hasattr(layer, "moe_ffn2_in_scale") else None), + (layer.up_gate_proj_weight_scale + if hasattr(layer, "up_gate_proj_weight_scale") else None), + (layer.down_proj_weight_scale + if hasattr(layer, "down_proj_weight_scale") else None), + 
(layer.down_proj_in_scale + if hasattr(layer, "down_proj_in_scale") else None), expert_idx_per_token, self.moe_quant_type, used_in_ep_low_latency, @@ -163,8 +164,8 @@ class CutlassMoEMethod(MoEMethodBase): recv_x, recv_topk_idx, recv_topk_weights, - (self.moe_ffn1_in_scale - if hasattr(self, "moe_ffn1_in_scale") else None), + (self.up_gate_proj_in_scale + if hasattr(self, "up_gate_proj_in_scale") else None), recv_num_tokens_per_expert_list, token_all_num, self.moe_quant_type, @@ -186,7 +187,7 @@ class CutlassMoEMethod(MoEMethodBase): dst_weights, permute_indices_per_token, dst_indices, - None, # moe_ffn2_bias, + None, # down_proj_bias, False, # norm_topk_prob 1.0, )[0] @@ -256,7 +257,7 @@ class CutlassMoEMethod(MoEMethodBase): x, gate_out, None, # Use layer.gate_correction_bias in get_moe_scores. - (layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale") + (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), # if set, permute_input will be int8_t layer.top_k, False, @@ -274,7 +275,7 @@ class CutlassMoEMethod(MoEMethodBase): x, gate_out, layer.gate_correction_bias, - (layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale") + (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), # if set, permute_input will be int8_t layer.top_k, False, @@ -323,9 +324,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): """ Paddle cutlass create weight process. """ - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) - self.check(layer, ffn1_weights, ffn2_weights) - for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]): + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + self.check(layer, up_gate_proj_weights, down_proj_weights) + for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): weight_name = self.added_weight_attrs[idx] weight_list = [] for i in range(layer.num_local_experts): @@ -366,26 +367,26 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): create_and_set_parameter(layer, name, processed_weight_scale) # 1. 
Init scale containers and maps - moe_ffn1_weight_scales = [] - moe_ffn2_weight_scales = [] - moe_ffn1_in_scales = [] - moe_ffn2_in_scales = [] + up_gate_proj_weight_scales = [] + down_proj_weight_scales = [] + up_gate_proj_in_scales = [] + down_proj_in_scales = [] scale_weight_map = { - "moe_ffn1_weight_scale": moe_ffn1_weight_scales, - "moe_ffn2_weight_scale": moe_ffn2_weight_scales, - "moe_ffn1_in_scale": moe_ffn1_in_scales, - "moe_ffn2_in_scale": moe_ffn2_in_scales, + "up_gate_proj_weight_scale": up_gate_proj_weight_scales, + "down_proj_weight_scale": down_proj_weight_scales, + "up_gate_proj_in_scale": up_gate_proj_in_scales, + "down_proj_in_scale": down_proj_in_scales, } scale_key_map = { - "moe_ffn1_weight_scale": - weight_key_map.get("ffn1_expert_weight_scale_key", None), - "moe_ffn2_weight_scale": - weight_key_map.get("ffn2_expert_weight_scale_key", None), - "moe_ffn1_in_scale": - weight_key_map.get("ffn1_expert_in_scale_key", None), - "moe_ffn2_in_scale": - weight_key_map.get("ffn2_expert_in_scale_key", None), + "up_gate_proj_weight_scale": + weight_key_map.get("up_gate_proj_expert_weight_scale_key", None), + "down_proj_weight_scale": + weight_key_map.get("down_proj_expert_weight_scale_key", None), + "up_gate_proj_in_scale": + weight_key_map.get("up_gate_proj_expert_in_scale_key", None), + "down_proj_in_scale": + weight_key_map.get("down_proj_expert_in_scale_key", None), } for name, value in scale_key_map.items(): if value is None: @@ -404,13 +405,13 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): # 3. Process scale tensor and set to layer in_scales = [] - for in_scale_name in ["moe_ffn1_in_scale", "moe_ffn2_in_scale"]: + for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: in_scales.append( _process_in_scale(in_scale_name, scale_weight_map[in_scale_name])) for i, weight_scale_name in enumerate( - ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]): + ["up_gate_proj_weight_scale", "down_proj_weight_scale"]): _process_weight_scale(weight_scale_name, scale_weight_map[weight_scale_name], in_scales[i]) @@ -431,41 +432,41 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod): """ Paddle cutlass process prequanted weights. 
""" - ffn1_expert_weight_key = layer.weight_key_map.get( - "ffn1_expert_weight_key", None) - ffn2_expert_weight_key = layer.weight_key_map.get( - "ffn2_expert_weight_key", None) - ffn1_expert_weight_scale_key = layer.weight_key_map.get( - "ffn1_expert_weight_scale_key", None) - ffn2_expert_weight_scale_key = layer.weight_key_map.get( - "ffn2_expert_weight_scale_key", None) + up_gate_proj_expert_weight_key = layer.weight_key_map.get( + "up_gate_proj_expert_weight_key", None) + down_proj_expert_weight_key = layer.weight_key_map.get( + "down_proj_expert_weight_key", None) + up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get( + "up_gate_proj_expert_weight_scale_key", None) + down_proj_expert_weight_scale_key = layer.weight_key_map.get( + "down_proj_expert_weight_scale_key", None) - ffn1_weights, ffn2_weights = layer.load_experts_weight( - state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key) - # self.check(layer, ffn1_weights, ffn2_weights) - ffn1_weight_scale = [] - ffn2_weight_scale = [] + up_gate_proj_weights, down_proj_weights = layer.load_experts_weight( + state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key) + # self.check(layer, up_gate_proj_weights, down_proj_weights) + up_gate_proj_weight_scale = [] + down_proj_weight_scale = [] for i in range(layer.num_local_experts): expert_idx = layer.expert_id_offset + i - ffn1_weight_scale.append( + up_gate_proj_weight_scale.append( get_tensor( state_dict.pop( - ffn1_expert_weight_scale_key.format(expert_idx)))) - ffn2_weight_scale.append( + up_gate_proj_expert_weight_scale_key.format(expert_idx)))) + down_proj_weight_scale.append( get_tensor( state_dict.pop( - ffn2_expert_weight_scale_key.format(expert_idx)))) + down_proj_expert_weight_scale_key.format(expert_idx)))) - ffn1_weight = paddle.stack(ffn1_weights, axis=0) - ffn2_weight = paddle.stack(ffn2_weights, axis=0) - ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0) - ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0) + up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0) + down_proj_weight = paddle.stack(down_proj_weights, axis=0) + up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0) + down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0) name_tensor_map = { - "moe_ffn1_weight": ffn1_weight, - "moe_ffn2_weight": ffn2_weight, - "moe_ffn1_weight_scale": ffn1_weight_scale, - "moe_ffn2_weight_scale": ffn2_weight_scale + "up_gate_proj_weight": up_gate_proj_weight, + "down_proj_weight": down_proj_weight, + "up_gate_proj_weight_scale": up_gate_proj_weight_scale, + "down_proj_weight_scale": down_proj_weight_scale } for name, tensor in name_tensor_map.items(): create_and_set_parameter(layer, name, tensor) @@ -474,10 +475,10 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod): """ Paddle cutlass create weight process. 
""" - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) - self.check(layer, ffn1_weights, ffn2_weights) + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + self.check(layer, up_gate_proj_weights, down_proj_weights) - for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]): + for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): weight_name = self.added_weight_attrs[idx] scale_name = self.added_scale_attrs[idx] diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 14301aa44..fbdd9e7ae 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -39,11 +39,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): deepgemm create weight process. """ - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) - self.check(layer, ffn1_weights, ffn2_weights) + self.check(layer, up_gate_proj_weights, down_proj_weights) - for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]): + for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): weight_name = self.added_weight_attrs[idx] scale_name = self.added_scale_attrs[idx] @@ -70,41 +70,41 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): """ Paddle cutlass process prequanted weights. """ - ffn1_expert_weight_key = layer.weight_key_map.get( - "ffn1_expert_weight_key", None) - ffn2_expert_weight_key = layer.weight_key_map.get( - "ffn2_expert_weight_key", None) - ffn1_expert_weight_scale_key = layer.weight_key_map.get( - "ffn1_expert_weight_scale_key", None) - ffn2_expert_weight_scale_key = layer.weight_key_map.get( - "ffn2_expert_weight_scale_key", None) + up_gate_proj_expert_weight_key = layer.weight_key_map.get( + "up_gate_proj_expert_weight_key", None) + down_proj_expert_weight_key = layer.weight_key_map.get( + "down_proj_expert_weight_key", None) + up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get( + "up_gate_proj_expert_weight_scale_key", None) + down_proj_expert_weight_scale_key = layer.weight_key_map.get( + "down_proj_expert_weight_scale_key", None) - ffn1_weights, ffn2_weights = layer.load_experts_weight( - state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key) - # self.check(layer, ffn1_weights, ffn2_weights) - ffn1_weight_scale = [] - ffn2_weight_scale = [] + up_gate_proj_weights, down_proj_weights = layer.load_experts_weight( + state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key) + # self.check(layer, up_gate_proj_weights, down_proj_weights) + up_gate_proj_weight_scale = [] + down_proj_weight_scale = [] for i in range(layer.num_local_experts): expert_idx = layer.expert_id_offset + i - ffn1_weight_scale.append( + up_gate_proj_weight_scale.append( get_tensor( state_dict.pop( - ffn1_expert_weight_scale_key.format(expert_idx)))) - ffn2_weight_scale.append( + up_gate_proj_expert_weight_scale_key.format(expert_idx)))) + down_proj_weight_scale.append( get_tensor( state_dict.pop( - ffn2_expert_weight_scale_key.format(expert_idx)))) + down_proj_expert_weight_scale_key.format(expert_idx)))) - ffn1_weight = paddle.stack(ffn1_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn") - ffn2_weight = paddle.stack(ffn2_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn") - 
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0).transpose([0, 2, 1]).contiguous() - ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0).transpose([0, 2, 1]).contiguous() + up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn") + down_proj_weight = paddle.stack(down_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn") + up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous() + down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous() name_tensor_map = { - "moe_ffn1_weight": ffn1_weight, - "moe_ffn2_weight": ffn2_weight, - "moe_ffn1_weight_scale": ffn1_weight_scale, - "moe_ffn2_weight_scale": ffn2_weight_scale + "up_gate_proj_weight": up_gate_proj_weight, + "down_proj_weight": down_proj_weight, + "up_gate_proj_weight_scale": up_gate_proj_weight_scale, + "down_proj_weight_scale": down_proj_weight_scale } for name, tensor in name_tensor_map.items(): create_and_set_parameter(layer, name, tensor) @@ -143,10 +143,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): if token_all_num > 0: logger.info(f"token_all_num {token_all_num}") (recv_x, recv_x_scale) = recv_x - + token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts) token_nums_this_rank_padded = sum(token_nums_this_rank[1].numpy().tolist()) - + ( permute_input, permute_scale, @@ -171,21 +171,21 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): permute_scale = permute_scale.transpose([1, 0]).contiguous() permute_scale = permute_scale.transpose([1, 0]) - # ffn1 + # up_gate_proj ffn_out = paddle.empty( - (permute_input.shape[0], layer.moe_ffn1_weight.shape[1]), + (permute_input.shape[0], layer.up_gate_proj_weight.shape[1]), dtype=paddle.bfloat16, ) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (permute_input, permute_scale), - (layer.moe_ffn1_weight, layer.moe_ffn1_weight_scale), + (layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale), ffn_out, m_indices, ) # swiglu ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None) - # ffn2 + # down_proj ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant( ffn_out, self.quant_config.weight_block_size[0]) ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose( @@ -193,11 +193,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]) ffn_out = paddle.empty( - (ffn_out.shape[0], layer.moe_ffn2_weight.shape[1]), + (ffn_out.shape[0], layer.down_proj_weight.shape[1]), dtype=paddle.bfloat16) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (ffn_in_x, ffn_in_x_scale_tensor), - (layer.moe_ffn2_weight, layer.moe_ffn2_weight_scale), + (layer.down_proj_weight, layer.down_proj_weight_scale), ffn_out, m_indices, ) @@ -207,7 +207,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): dst_weights, permute_indices_per_token, dst_indices, - None, # moe_ffn2_bias + None, # down_proj_bias False, # norm_topk_prob 1.0, )[0] @@ -237,7 +237,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): # 3. 
Compute ffn assert isinstance(permute_input, tuple) - ffn1_out = paddle.empty( + up_gate_proj_out = paddle.empty( [ layer.num_local_experts, layer.ep_size * @@ -261,16 +261,16 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( permute_input, ( - layer.moe_ffn1_weight, - layer.moe_ffn1_weight_scale, + layer.up_gate_proj_weight, + layer.up_gate_proj_weight_scale, ), - ffn1_out, + up_gate_proj_out, token_nums_per_expert, expected_m, ) act_out = fastdeploy.model_executor.ops.gpu.group_swiglu_with_masked( - ffn1_out, token_nums_per_expert) + up_gate_proj_out, token_nums_per_expert) act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.masked_per_token_quant( act_out, token_nums_per_expert, @@ -279,8 +279,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( (act_out_fp8, scale), ( - layer.moe_ffn2_weight, - layer.moe_ffn2_weight_scale, + layer.down_proj_weight, + layer.down_proj_weight_scale, ), ffn_out, token_nums_per_expert, @@ -339,21 +339,21 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): permute_scale = permute_scale.transpose([1, 0]).contiguous() permute_scale = permute_scale.transpose([1, 0]) - # ffn1 + # up_gate_proj ffn_out = paddle.empty( - (permute_input.shape[0], layer.moe_ffn1_weight.shape[1]), + (permute_input.shape[0], layer.up_gate_proj_weight.shape[1]), dtype=paddle.bfloat16, ) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (permute_input, permute_scale), - (layer.moe_ffn1_weight, layer.moe_ffn1_weight_scale), + (layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale), ffn_out, m_indices, ) # swiglu ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out) - # ffn2 + # down_proj ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant( ffn_out, self.quant_config.weight_block_size[0]) @@ -362,11 +362,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]) ffn_out = paddle.empty( - (ffn_out.shape[0], layer.moe_ffn2_weight.shape[1]), + (ffn_out.shape[0], layer.down_proj_weight.shape[1]), dtype=paddle.bfloat16) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (ffn_in_x, ffn_in_x_scale_tensor), - (layer.moe_ffn2_weight, layer.moe_ffn2_weight_scale), + (layer.down_proj_weight, layer.down_proj_weight_scale), ffn_out, m_indices, ) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index ceb18edf0..da308a0b8 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -103,9 +103,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): Marlin Group Gemm to compute Fused MoE. """ self.quant_method = quant_method - self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"] + self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] self.added_scale_attrs = [ - "moe_ffn1_weight_scale", "moe_ffn2_weight_scale" + "up_gate_proj_weight_scale", "down_proj_weight_scale" ] self.added_zeros_attrs = ["zeros0", "zeros1"] @@ -113,22 +113,22 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): """ Marlin MoE create weight process. 
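Across these backends the per-expert dataflow is the same: a fused up_gate_proj GEMM, a SwiGLU that folds the gate and up halves together, then the down_proj GEMM. That is why the shape assertions that follow expect [hidden_size, 2 * moe_intermediate_size] for up_gate_proj and [moe_intermediate_size, hidden_size] for down_proj. A toy, unquantized shape walk-through, assuming the single-input swiglu splits its last dimension in half, as the calls above use it:

    import paddle

    tokens, hidden_size, inter_size = 4, 64, 128   # toy sizes
    x = paddle.randn([tokens, hidden_size])
    w_up_gate = paddle.randn([hidden_size, inter_size * 2])
    w_down = paddle.randn([inter_size, hidden_size])

    h = paddle.matmul(x, w_up_gate)               # [tokens, 2 * inter_size]
    h = paddle.incubate.nn.functional.swiglu(h)   # [tokens, inter_size]
    out = paddle.matmul(h, w_down)                # [tokens, hidden_size]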
""" - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) - assert len(ffn1_weights) == layer.num_local_experts - assert len(ffn2_weights) == layer.num_local_experts - assert ffn1_weights[0].shape == [ + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + assert len(up_gate_proj_weights) == layer.num_local_experts + assert len(down_proj_weights) == layer.num_local_experts + assert up_gate_proj_weights[0].shape == [ layer.hidden_size, layer.moe_intermediate_size * 2 ] - assert ffn2_weights[0].shape == [ + assert down_proj_weights[0].shape == [ layer.moe_intermediate_size, layer.hidden_size ] - ffn1_tensor = paddle.stack(ffn1_weights, axis=0) - ffn2_tensor = paddle.stack(ffn2_weights, axis=0) + up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0) + down_proj_tensor = paddle.stack(down_proj_weights, axis=0) max_bound = 7 - for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]): + for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]): weight_name = self.added_weight_attrs[idx] scale_name = self.added_scale_attrs[idx] @@ -221,8 +221,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): ffn_out = MoeWna16MarlinGemmApi( x, c_or_none=None, - b_q_weight=layer.moe_ffn1_weight, - b_scales=layer.moe_ffn1_weight_scale, + b_q_weight=layer.up_gate_proj_weight, + b_scales=layer.up_gate_proj_weight_scale, global_scale_or_none=None, b_zeros_or_none=None, g_idx_or_none=None, @@ -250,8 +250,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): ffn_out = MoeWna16MarlinGemmApi( swiglu_out, c_or_none=None, - b_q_weight=layer.moe_ffn2_weight, - b_scales=layer.moe_ffn2_weight_scale, + b_q_weight=layer.down_proj_weight, + b_scales=layer.down_proj_weight_scale, global_scale_or_none=None, b_zeros_or_none=None, g_idx_or_none=None, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index c113fe712..512f76c81 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -30,7 +30,7 @@ try: from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func from .triton_moe_kernels import fused_moe_kernel_paddle -except: +except ImportError: pass @@ -44,9 +44,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): Triton Group Gemm to compute Fused MoE. """ self.quant_config = quant_config - self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"] + self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] self.added_scale_attrs = [ - "moe_ffn1_weight_scale", "moe_ffn2_weight_scale" + "up_gate_proj_weight_scale", "down_proj_weight_scale" ] def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None: @@ -57,30 +57,30 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): """ Triton MoE create weight process. 
""" - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) - assert len(ffn1_weights) == layer.num_local_experts - assert len(ffn2_weights) == layer.num_local_experts + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + assert len(up_gate_proj_weights) == layer.num_local_experts + assert len(down_proj_weights) == layer.num_local_experts algo = layer.quant_method.quant_config.name() assert algo == "wint8" - assert ffn1_weights[0].shape == [ + assert up_gate_proj_weights[0].shape == [ layer.hidden_size, layer.moe_intermediate_size * 2 ] - assert ffn2_weights[0].shape == [ + assert down_proj_weights[0].shape == [ layer.moe_intermediate_size, layer.hidden_size ] - ffn1_tensor = paddle.stack(ffn1_weights, axis=0) - ffn2_tensor = paddle.stack(ffn2_weights, axis=0) + up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0) + down_proj_tensor = paddle.stack(down_proj_weights, axis=0) if algo == "wint8": max_bound = 127 elif algo == "wint4": max_bound = 7 - for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]): + for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]): weight_name = self.added_weight_attrs[idx] scale_name = self.added_scale_attrs[idx] @@ -130,7 +130,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): True, # apply_norm_weight, False, ) - ffn1_out = paddle.empty( + up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, ) @@ -150,10 +150,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): fused_moe_kernel_paddle[grid]( x, - layer.moe_ffn1_weight, - ffn1_out, + layer.up_gate_proj_weight, + up_gate_proj_out, None, - layer.moe_ffn1_weight_scale, + layer.up_gate_proj_weight_scale, None, sorted_token_ids, expert_ids, @@ -164,17 +164,17 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): K=hidden_size, stride_am=x.strides[0], stride_ak=x.strides[1], - stride_be=layer.moe_ffn1_weight.strides[0], - stride_bk=layer.moe_ffn1_weight.strides[1], - stride_bn=layer.moe_ffn1_weight.strides[2], - stride_cm=ffn1_out.strides[0], - stride_cn=ffn1_out.strides[1], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], + stride_cm=up_gate_proj_out.strides[0], + stride_cn=up_gate_proj_out.strides[1], # stride_asm=-1, stride_ask=-1, - stride_bse=layer.moe_ffn1_weight_scale.strides[0], + stride_bse=layer.up_gate_proj_weight_scale.strides[0], stride_bsk=-1, - stride_bsn=layer.moe_ffn1_weight_scale.strides[1], + stride_bsn=layer.up_gate_proj_weight_scale.strides[1], group_n=-1, group_k=-1, # Meta-parameters @@ -190,10 +190,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0, ) - ffn2_input = paddle.incubate.nn.functional.swiglu( - ffn1_out) + down_proj_input = paddle.incubate.nn.functional.swiglu( + up_gate_proj_out) - ffn2_out = paddle.empty( + down_proj_out = paddle.empty( (token_num * top_k, hidden_size), dtype=x.dtype, ) @@ -202,11 +202,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]), ) fused_moe_kernel_paddle[grid]( - ffn2_input, - layer.moe_ffn2_weight, - ffn2_out, + down_proj_input, + layer.down_proj_weight, + down_proj_out, None, - layer.moe_ffn2_weight_scale, + layer.down_proj_weight_scale, topk_weights, sorted_token_ids, expert_ids, @@ -215,18 +215,18 @@ class 
TritonWeightOnlyMoEMethod(QuantMethodBase): token_num * top_k, N=hidden_size, K=moe_intermediate_size, - stride_am=ffn2_input.strides[0], - stride_ak=ffn2_input.strides[1], - stride_be=layer.moe_ffn2_weight.strides[0], - stride_bk=layer.moe_ffn2_weight.strides[1], - stride_bn=layer.moe_ffn2_weight.strides[2], - stride_cm=ffn2_out.strides[0], - stride_cn=ffn2_out.strides[1], + stride_am=down_proj_input.strides[0], + stride_ak=down_proj_input.strides[1], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[1], + stride_bn=layer.down_proj_weight.strides[2], + stride_cm=down_proj_out.strides[0], + stride_cn=down_proj_out.strides[1], stride_asm=-1, stride_ask=-1, - stride_bse=layer.moe_ffn2_weight_scale.strides[0], + stride_bse=layer.down_proj_weight_scale.strides[0], stride_bsk=-1, - stride_bsn=layer.moe_ffn2_weight_scale.strides[1], + stride_bsn=layer.down_proj_weight_scale.strides[1], group_n=-1, group_k=-1, # Meta-parameters @@ -242,8 +242,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0, ) - ffn2_out.reshape_([token_num, top_k, hidden_size]) - out = ffn2_out.sum(axis=1) + down_proj_out.reshape_([token_num, top_k, hidden_size]) + out = down_proj_out.sum(axis=1) return out @@ -261,20 +261,20 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None: """process_prequanted_weights""" - ffn1_tensor, ffn2_tensor = layer.extract_moe_ffn_weights(state_dict) - assert ffn1_tensor[0].shape == [ + up_gate_proj_tensor, down_proj_tensor = layer.extract_moe_ffn_weights(state_dict) + assert up_gate_proj_tensor[0].shape == [ layer.hidden_size, layer.moe_intermediate_size * 2 ] - assert ffn2_tensor[0].shape == [ + assert down_proj_tensor[0].shape == [ layer.moe_intermediate_size, layer.hidden_size ] - ffn1_tensor = paddle.stack(ffn1_tensor, axis=0).view(paddle.float8_e4m3fn) - ffn2_tensor = paddle.stack(ffn2_tensor, axis=0).view(paddle.float8_e4m3fn) + up_gate_proj_tensor = paddle.stack(up_gate_proj_tensor, axis=0).view(paddle.float8_e4m3fn) + down_proj_tensor = paddle.stack(down_proj_tensor, axis=0).view(paddle.float8_e4m3fn) added_wfp8afp8_attrs = [ - "moe_ffn1_weight", "moe_ffn2_weight", "moe_ffn1_weight_scale", - "moe_ffn2_weight_scale", "moe_ffn1_in_scale", "moe_ffn2_in_scale" + "up_gate_proj_weight", "down_proj_weight", "up_gate_proj_weight_scale", + "down_proj_weight_scale", "up_gate_proj_in_scale", "down_proj_in_scale" ] def _extract_scale_tensor(key_template): @@ -285,18 +285,18 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): return paddle.concat(result).cast("float32") weight_key_map = layer.weight_key_map - moe_ffn1_weight_scale = _extract_scale_tensor( - weight_key_map["ffn1_expert_weight_scale_key"]) - moe_ffn2_weight_scale = _extract_scale_tensor( - weight_key_map["ffn2_expert_weight_scale_key"]) - moe_ffn1_in_scale = _extract_scale_tensor( - weight_key_map["ffn1_expert_in_scale_key"]) - moe_ffn2_in_scale = _extract_scale_tensor( - weight_key_map["ffn2_expert_in_scale_key"]) + up_gate_proj_weight_scale = _extract_scale_tensor( + weight_key_map["up_gate_proj_expert_weight_scale_key"]) + down_proj_weight_scale = _extract_scale_tensor( + weight_key_map["down_proj_expert_weight_scale_key"]) + up_gate_proj_in_scale = _extract_scale_tensor( + weight_key_map["up_gate_proj_expert_in_scale_key"]) + down_proj_in_scale = _extract_scale_tensor( + weight_key_map["down_proj_expert_in_scale_key"]) for idx, weight_tensor in enumerate([ - 
ffn1_tensor, ffn2_tensor, moe_ffn1_weight_scale, - moe_ffn2_weight_scale, moe_ffn1_in_scale, moe_ffn2_in_scale + up_gate_proj_tensor, down_proj_tensor, up_gate_proj_weight_scale, + down_proj_weight_scale, up_gate_proj_in_scale, down_proj_in_scale ]): name = added_wfp8afp8_attrs[idx] setattr( @@ -341,12 +341,12 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): False, ) - ffn1_out = paddle.empty( + up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, ) - config_ffn1 = { + config_up_gate_proj = { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, @@ -354,15 +354,15 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): } sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( - topk_ids, num_local_experts, config_ffn1["BLOCK_SIZE_M"]) + topk_ids, num_local_experts, config_up_gate_proj["BLOCK_SIZE_M"]) max_possible_num_post_padded = sorted_token_ids.shape[0] grid = ( - ceil_div(max_possible_num_post_padded, config_ffn1["BLOCK_SIZE_M"]) * - ceil_div(moe_intermediate_size * 2, config_ffn1["BLOCK_SIZE_N"]), ) + ceil_div(max_possible_num_post_padded, config_up_gate_proj["BLOCK_SIZE_M"]) * + ceil_div(moe_intermediate_size * 2, config_up_gate_proj["BLOCK_SIZE_N"]), ) permute_x = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8( x, - scale=layer.moe_ffn1_in_scale, + scale=layer.up_gate_proj_in_scale, topk_ids=topk_ids, top_k=top_k, intermediate_size=hidden_size, @@ -370,10 +370,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): fused_moe_kernel_paddle[grid]( permute_x, - layer.moe_ffn1_weight, - ffn1_out, - layer.moe_ffn1_in_scale, - layer.moe_ffn1_weight_scale, + layer.up_gate_proj_weight, + up_gate_proj_out, + layer.up_gate_proj_in_scale, + layer.up_gate_proj_weight_scale, None, sorted_token_ids, expert_ids, @@ -384,11 +384,11 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): K=hidden_size, stride_am=x.strides[0], stride_ak=x.strides[1], - stride_be=layer.moe_ffn1_weight.strides[0], - stride_bk=layer.moe_ffn1_weight.strides[1], - stride_bn=layer.moe_ffn1_weight.strides[2], - stride_cm=ffn1_out.strides[0], - stride_cn=ffn1_out.strides[1], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], + stride_cm=up_gate_proj_out.strides[0], + stride_cn=up_gate_proj_out.strides[1], # stride_asm=-1, # only used in blockwise fp8 stride_ask=-1, # only used in blockwise fp8 @@ -398,51 +398,51 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): group_n=-1, group_k=-1, # Meta-parameters - BLOCK_SIZE_M=config_ffn1["BLOCK_SIZE_M"], - BLOCK_SIZE_N=config_ffn1["BLOCK_SIZE_N"], - BLOCK_SIZE_K=config_ffn1["BLOCK_SIZE_K"], - GROUP_SIZE_M=config_ffn1["GROUP_SIZE_M"], + BLOCK_SIZE_M=config_up_gate_proj["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config_up_gate_proj["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config_up_gate_proj["BLOCK_SIZE_K"], + GROUP_SIZE_M=config_up_gate_proj["GROUP_SIZE_M"], MUL_ROUTED_WEIGHT=False, top_k=1, compute_type_enum=1, use_fp8_w8a8=True, use_int8_w8a16=False, - even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0, + even_Ks=hidden_size % config_up_gate_proj["BLOCK_SIZE_K"] == 0, ) - ffn2_input = paddle.incubate.nn.functional.swiglu( - ffn1_out) + down_proj_input = paddle.incubate.nn.functional.swiglu( + up_gate_proj_out) - ffn2_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8( - ffn2_input, - scale=layer.moe_ffn2_in_scale, + down_proj_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8( + 
down_proj_input, + scale=layer.down_proj_in_scale, topk_ids=topk_ids, top_k=top_k, intermediate_size=moe_intermediate_size, tiled=True) - config_ffn2 = { + config_down_proj = { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, } - ffn2_out = paddle.empty( + down_proj_out = paddle.empty( (token_num * top_k, hidden_size), dtype=x.dtype, ) grid = ( - ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) * - ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), ) + ceil_div(max_possible_num_post_padded, config_down_proj["BLOCK_SIZE_M"]) * + ceil_div(hidden_size, config_down_proj["BLOCK_SIZE_N"]), ) fused_moe_kernel_paddle[grid]( - ffn2_input, - layer.moe_ffn2_weight, - ffn2_out, - layer.moe_ffn2_in_scale, - layer.moe_ffn2_weight_scale, + down_proj_input, + layer.down_proj_weight, + down_proj_out, + layer.down_proj_in_scale, + layer.down_proj_weight_scale, topk_weights, sorted_token_ids, expert_ids, @@ -451,13 +451,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): token_num * top_k, N=hidden_size, K=moe_intermediate_size, - stride_am=ffn2_input.strides[0], - stride_ak=ffn2_input.strides[1], - stride_be=layer.moe_ffn2_weight.strides[0], - stride_bk=layer.moe_ffn2_weight.strides[1], - stride_bn=layer.moe_ffn2_weight.strides[2], - stride_cm=ffn2_out.strides[0], - stride_cn=ffn2_out.strides[1], + stride_am=down_proj_input.strides[0], + stride_ak=down_proj_input.strides[1], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[1], + stride_bn=layer.down_proj_weight.strides[2], + stride_cm=down_proj_out.strides[0], + stride_cn=down_proj_out.strides[1], stride_asm=-1, stride_ask=-1, stride_bse=-1, @@ -466,20 +466,20 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): group_n=-1, group_k=-1, # Meta-parameters - BLOCK_SIZE_M=config_ffn2["BLOCK_SIZE_M"], - BLOCK_SIZE_N=config_ffn2["BLOCK_SIZE_N"], - BLOCK_SIZE_K=config_ffn2["BLOCK_SIZE_K"], - GROUP_SIZE_M=config_ffn2["GROUP_SIZE_M"], + BLOCK_SIZE_M=config_down_proj["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config_down_proj["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config_down_proj["BLOCK_SIZE_K"], + GROUP_SIZE_M=config_down_proj["GROUP_SIZE_M"], MUL_ROUTED_WEIGHT=True, top_k=1, compute_type_enum=1, use_fp8_w8a8=True, use_int8_w8a16=False, - even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0, + even_Ks=moe_intermediate_size % config_down_proj["BLOCK_SIZE_K"] == 0, ) - ffn2_out.reshape_([token_num, top_k, hidden_size]) - out = ffn2_out.sum(axis=1) + down_proj_out.reshape_([token_num, top_k, hidden_size]) + out = down_proj_out.sum(axis=1) if layer.tp_size > 1: tensor_model_parallel_all_reduce(out) @@ -496,9 +496,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): Triton Group Gemm to compute Fused MoE. """ self.quant_config = quant_config - self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"] + self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] self.added_scale_attrs = [ - "moe_ffn1_weight_scale", "moe_ffn2_weight_scale" + "up_gate_proj_weight_scale", "down_proj_weight_scale" ] def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None: @@ -510,11 +510,11 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): """ Triton MoE create weight process. 
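Both Triton paths above finish the same way: the routing weights are already multiplied inside the second kernel (MUL_ROUTED_WEIGHT=True), so combining the top-k expert outputs is just a reshape followed by a sum over the top_k axis:

    import paddle

    token_num, top_k, hidden_size = 3, 2, 8
    down_proj_out = paddle.randn([token_num * top_k, hidden_size])
    out = down_proj_out.reshape([token_num, top_k, hidden_size]).sum(axis=1)  # [token_num, hidden_size]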
""" - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) - self.check(layer, ffn1_weights, ffn2_weights) + self.check(layer, up_gate_proj_weights, down_proj_weights) - for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]): + for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): weight_name = self.added_weight_attrs[idx] scale_name = self.added_scale_attrs[idx] @@ -537,14 +537,14 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): [0, 2, 1]).contiguous() create_and_set_parameter(layer, scale_name, quanted_weight_scale) - def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights): + def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights): """ check layer is valid for this method """ - assert ffn1_weights[0].shape == [ + assert up_gate_proj_weights[0].shape == [ layer.hidden_size, layer.moe_intermediate_size * 2 ] - assert ffn2_weights[0].shape == [ + assert down_proj_weights[0].shape == [ layer.moe_intermediate_size, layer.hidden_size ] @@ -563,8 +563,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): num_local_experts = layer.num_local_experts moe_intermediate_size = layer.moe_intermediate_size hidden_size = layer.hidden_size - E, N1, _ = layer.moe_ffn1_weight.shape - N2 = layer.moe_ffn2_weight.shape[1] + E, N1, _ = layer.up_gate_proj_weight.shape + N2 = layer.down_proj_weight.shape[1] topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, @@ -605,10 +605,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): fused_moe_kernel_paddle[grid]( x_q, - layer.moe_ffn1_weight.view(paddle.float8_e4m3fn), + layer.up_gate_proj_weight.view(paddle.float8_e4m3fn), intermediate_cache1, x_scale, - layer.moe_ffn1_weight_scale, + layer.up_gate_proj_weight_scale, None, sorted_token_ids, expert_ids, @@ -619,17 +619,17 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): K=hidden_size, stride_am=x_q.strides[0], stride_ak=x_q.strides[1], - stride_be=layer.moe_ffn1_weight.strides[0], - stride_bk=layer.moe_ffn1_weight.strides[2], - stride_bn=layer.moe_ffn1_weight.strides[1], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[2], + stride_bn=layer.up_gate_proj_weight.strides[1], stride_cm=intermediate_cache1.strides[0], stride_cn=intermediate_cache1.strides[1], # stride_asm=x_scale.strides[0], # only used in blockwise fp8 stride_ask=x_scale.strides[1], # only used in blockwise fp8 - stride_bse=layer.moe_ffn1_weight_scale.strides[0], - stride_bsk=layer.moe_ffn1_weight_scale.strides[2], - stride_bsn=layer.moe_ffn1_weight_scale.strides[1], + stride_bse=layer.up_gate_proj_weight_scale.strides[0], + stride_bsk=layer.up_gate_proj_weight_scale.strides[2], + stride_bsn=layer.up_gate_proj_weight_scale.strides[1], group_n=self.quant_config.weight_block_size[1], group_k=self.quant_config.weight_block_size[0], # Meta-parameters @@ -656,10 +656,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): fused_moe_kernel_paddle[grid]( x_q, - layer.moe_ffn2_weight.view(paddle.float8_e4m3fn), + layer.down_proj_weight.view(paddle.float8_e4m3fn), intermediate_cache3, x_scale, - layer.moe_ffn2_weight_scale, + layer.down_proj_weight_scale, topk_weights, sorted_token_ids, expert_ids, @@ -670,16 +670,16 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): K=moe_intermediate_size, stride_am=x_q.strides[0], stride_ak=x_q.strides[1], - stride_be=layer.moe_ffn2_weight.strides[0], - stride_bk=layer.moe_ffn2_weight.strides[2], - 
stride_bn=layer.moe_ffn2_weight.strides[1], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[2], + stride_bn=layer.down_proj_weight.strides[1], stride_cm=intermediate_cache3.strides[0], stride_cn=intermediate_cache3.strides[1], stride_asm=x_scale.strides[0], # only used in blockwise fp8 stride_ask=x_scale.strides[1], # only used in blockwise fp8 - stride_bse=layer.moe_ffn2_weight_scale.strides[0], - stride_bsk=layer.moe_ffn2_weight_scale.strides[2], - stride_bsn=layer.moe_ffn2_weight_scale.strides[1], + stride_bse=layer.down_proj_weight_scale.strides[0], + stride_bsk=layer.down_proj_weight_scale.strides[2], + stride_bsn=layer.down_proj_weight_scale.strides[1], group_n=self.quant_config.weight_block_size[1], group_k=self.quant_config.weight_block_size[0], # Meta-parameters diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py index ca81b149e..5ec8c31af 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py @@ -41,16 +41,16 @@ class Wint2MoeMethod(QuantMethodBase): """ pass - def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights): + def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights): """ check layer is valid for this method """ assert len( - ffn1_weights - ) == layer.num_local_experts, "ffn1_weights length should be equal to num_local_experts." + up_gate_proj_weights + ) == layer.num_local_experts, "up_gate_proj_weights length should be equal to num_local_experts." assert len( - ffn2_weights - ) == layer.num_local_experts, "ffn2_weights length should be equal to num_local_experts." + down_proj_weights + ) == layer.num_local_experts, "down_proj_weights length should be equal to num_local_experts." def create_weights(self, layer: nn.Layer, state_dict): """ @@ -78,96 +78,96 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod): """ Paddle cutlass process prequanted weights. 
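The block-wise FP8 kernels above take group_n/group_k from quant_config.weight_block_size, meaning one scale per [block_k, block_n] tile of each expert weight rather than one scale per tensor. A counting-only sketch, assuming 128x128 blocks and the float8_e4m3 maximum of 448; the actual quantization and scale layout are produced by fused ops, not this loop:

    import paddle

    def blockwise_scales(w, block_k=128, block_n=128, fp8_max=448.0):
        # Sketch only. w: [K, N] for one expert; returns [K // block_k, N // block_n] scales.
        K, N = w.shape
        tiles = w.reshape([K // block_k, block_k, N // block_n, block_n])
        amax = paddle.max(paddle.abs(tiles), axis=[1, 3])
        return amax / fp8_max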
""" - ffn1_expert_weight_key = layer.weight_key_map.get( - "ffn1_expert_weight_key", None) - ffn2_expert_weight_key = layer.weight_key_map.get( - "ffn2_expert_weight_key", None) - ffn1_expert_weight_scale_key = layer.weight_key_map.get( - "ffn1_expert_weight_scale_key", None) - ffn2_expert_weight_scale_key = layer.weight_key_map.get( - "ffn2_expert_weight_scale_key", None) - ffn1_expert_super_scales_key = layer.weight_key_map.get( - "ffn1_expert_super_scales_key", None) - ffn2_expert_super_scales_key = layer.weight_key_map.get( - "ffn2_expert_super_scales_key", None) - ffn1_expert_code_scale_key = layer.weight_key_map.get( - "ffn1_expert_code_scale_key", None) - ffn2_expert_code_scale_key = layer.weight_key_map.get( - "ffn2_expert_code_scale_key", None) - ffn1_expert_code_zp_key = layer.weight_key_map.get( - "ffn1_expert_code_zp_key", None) - ffn2_expert_code_zp_key = layer.weight_key_map.get( - "ffn2_expert_code_zp_key", None) + up_gate_proj_expert_weight_key = layer.weight_key_map.get( + "up_gate_proj_expert_weight_key", None) + down_proj_expert_weight_key = layer.weight_key_map.get( + "down_proj_expert_weight_key", None) + up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get( + "up_gate_proj_expert_weight_scale_key", None) + down_proj_expert_weight_scale_key = layer.weight_key_map.get( + "down_proj_expert_weight_scale_key", None) + up_gate_proj_expert_super_scales_key = layer.weight_key_map.get( + "up_gate_proj_expert_super_scales_key", None) + down_proj_expert_super_scales_key = layer.weight_key_map.get( + "down_proj_expert_super_scales_key", None) + up_gate_proj_expert_code_scale_key = layer.weight_key_map.get( + "up_gate_proj_expert_code_scale_key", None) + down_proj_expert_code_scale_key = layer.weight_key_map.get( + "down_proj_expert_code_scale_key", None) + up_gate_proj_expert_code_zp_key = layer.weight_key_map.get( + "up_gate_proj_expert_code_zp_key", None) + down_proj_expert_code_zp_key = layer.weight_key_map.get( + "down_proj_expert_code_zp_key", None) - ffn1_weights, ffn2_weights = layer.load_experts_weight( - state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key) - # self.check(layer, ffn1_weights, ffn2_weights) + up_gate_proj_weights, down_proj_weights = layer.load_experts_weight( + state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key) + # self.check(layer, up_gate_proj_weights, down_proj_weights) - ffn1_weight_scale = [] - ffn2_weight_scale = [] - ffn1_super_scales = [] - ffn2_super_scales = [] - ffn1_code_scale = [] - ffn2_code_scale = [] - ffn1_code_zp = [] - ffn2_code_zp = [] + up_gate_proj_weight_scale = [] + down_proj_weight_scale = [] + up_gate_proj_super_scales = [] + down_proj_super_scales = [] + up_gate_proj_code_scale = [] + down_proj_code_scale = [] + up_gate_proj_code_zp = [] + down_proj_code_zp = [] for i in range(layer.num_experts): expert_idx = layer.expert_id_offset + i - ffn1_weight_scale.append( + up_gate_proj_weight_scale.append( get_tensor( state_dict.pop( - ffn1_expert_weight_scale_key.format(expert_idx)))) - ffn2_weight_scale.append( + up_gate_proj_expert_weight_scale_key.format(expert_idx)))) + down_proj_weight_scale.append( get_tensor( state_dict.pop( - ffn2_expert_weight_scale_key.format(expert_idx)))) - ffn1_super_scales.append( + down_proj_expert_weight_scale_key.format(expert_idx)))) + up_gate_proj_super_scales.append( get_tensor( state_dict.pop( - ffn1_expert_super_scales_key.format(expert_idx)))) - ffn2_super_scales.append( + up_gate_proj_expert_super_scales_key.format(expert_idx)))) + 
down_proj_super_scales.append( get_tensor( state_dict.pop( - ffn2_expert_super_scales_key.format(expert_idx)))) - ffn1_code_scale.append( + down_proj_expert_super_scales_key.format(expert_idx)))) + up_gate_proj_code_scale.append( get_tensor( state_dict.pop( - ffn1_expert_code_scale_key.format(expert_idx)))) - ffn2_code_scale.append( + up_gate_proj_expert_code_scale_key.format(expert_idx)))) + down_proj_code_scale.append( get_tensor( state_dict.pop( - ffn2_expert_code_scale_key.format(expert_idx)))) - ffn1_code_zp.append( + down_proj_expert_code_scale_key.format(expert_idx)))) + up_gate_proj_code_zp.append( get_tensor( state_dict.pop( - ffn1_expert_code_zp_key.format(expert_idx)))) - ffn2_code_zp.append( + up_gate_proj_expert_code_zp_key.format(expert_idx)))) + down_proj_code_zp.append( get_tensor( state_dict.pop( - ffn2_expert_code_zp_key.format(expert_idx)))) + down_proj_expert_code_zp_key.format(expert_idx)))) - ffn1_weight = paddle.stack(ffn1_weights, axis=0) - ffn2_weight = paddle.stack(ffn2_weights, axis=0) - ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0) - ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0) - ffn1_super_scales = paddle.stack(ffn1_super_scales, axis=0) - ffn2_super_scales = paddle.stack(ffn2_super_scales, axis=0) - ffn1_code_scale = paddle.stack(ffn1_code_scale, axis=0) - ffn2_code_scale = paddle.stack(ffn2_code_scale, axis=0) - ffn1_code_zp = paddle.stack(ffn1_code_zp, axis=0) - ffn2_code_zp = paddle.stack(ffn2_code_zp, axis=0) + up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0) + down_proj_weight = paddle.stack(down_proj_weights, axis=0) + up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0) + down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0) + up_gate_proj_super_scales = paddle.stack(up_gate_proj_super_scales, axis=0) + down_proj_super_scales = paddle.stack(down_proj_super_scales, axis=0) + up_gate_proj_code_scale = paddle.stack(up_gate_proj_code_scale, axis=0) + down_proj_code_scale = paddle.stack(down_proj_code_scale, axis=0) + up_gate_proj_code_zp = paddle.stack(up_gate_proj_code_zp, axis=0) + down_proj_code_zp = paddle.stack(down_proj_code_zp, axis=0) name_tensor_map = { - "moe_ffn1_weight": ffn1_weight, - "moe_ffn2_weight": ffn2_weight, - "moe_ffn1_weight_scale": ffn1_weight_scale, - "moe_ffn2_weight_scale": ffn2_weight_scale, - "moe_ffn1_super_scales": ffn1_super_scales, - "moe_ffn2_super_scales": ffn2_super_scales, - "moe_ffn1_code_scale": ffn1_code_scale, - "moe_ffn2_code_scale": ffn2_code_scale, - "moe_ffn1_code_zp": ffn1_code_zp, - "moe_ffn2_code_zp": ffn2_code_zp + "up_gate_proj_weight": up_gate_proj_weight, + "down_proj_weight": down_proj_weight, + "up_gate_proj_weight_scale": up_gate_proj_weight_scale, + "down_proj_weight_scale": down_proj_weight_scale, + "up_gate_proj_super_scales": up_gate_proj_super_scales, + "down_proj_super_scales": down_proj_super_scales, + "up_gate_proj_code_scale": up_gate_proj_code_scale, + "down_proj_code_scale": down_proj_code_scale, + "up_gate_proj_code_zp": up_gate_proj_code_zp, + "down_proj_code_zp": down_proj_code_zp } for name, tensor in name_tensor_map.items(): create_and_set_parameter(layer, name, tensor) @@ -200,7 +200,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod): x, gate_out, layer.gate_correction_bias, - (layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale") + (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), # if set, permute_input will be int8_t layer.top_k, False, @@ -210,17 +210,17 
@@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod): ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2( permute_input, token_nums_per_expert, - layer.moe_ffn1_weight, - layer.moe_ffn2_weight, + layer.up_gate_proj_weight, + layer.down_proj_weight, None, - layer.moe_ffn1_super_scales, - layer.moe_ffn2_super_scales, - layer.moe_ffn1_weight_scale, - layer.moe_ffn1_code_scale, - layer.moe_ffn1_code_zp, - layer.moe_ffn2_weight_scale, - layer.moe_ffn2_code_scale, - layer.moe_ffn2_code_zp, + layer.up_gate_proj_super_scales, + layer.down_proj_super_scales, + layer.up_gate_proj_weight_scale, + layer.up_gate_proj_code_scale, + layer.up_gate_proj_code_zp, + layer.down_proj_weight_scale, + layer.down_proj_code_scale, + layer.down_proj_code_zp, False, ) @@ -271,7 +271,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): ) num_tokens, K = x.shape - E, _, N = layer.moe_ffn1_weight.shape + E, _, N = layer.up_gate_proj_weight.shape M = num_tokens top_k = topk_ids.shape[1] @@ -308,12 +308,12 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): moe_wint2_ffn_kernel[grid]( x, - layer.moe_ffn1_weight, + layer.up_gate_proj_weight, intermediate_cache1, - layer.moe_ffn1_weight_scale, - layer.moe_ffn1_super_scales, - layer.moe_ffn1_code_scale, - layer.moe_ffn1_code_zp, + layer.up_gate_proj_weight_scale, + layer.up_gate_proj_super_scales, + layer.up_gate_proj_code_scale, + layer.up_gate_proj_code_zp, topk_weights, sorted_token_ids, expert_ids, @@ -321,7 +321,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): num_valid_tokens, max_possible_num_post_padded, # Matrix dimensions - N=layer.moe_ffn1_weight.shape[-1], + N=layer.up_gate_proj_weight.shape[-1], K=x.shape[-1], # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is @@ -329,15 +329,15 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): # (A has M rows). 
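The one-dimensional launch grids used throughout these Triton backends are simply the product of the M-tile count over the padded token rows and the N-tile count over the output dimension; the grid expressions amount to:

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    max_possible_num_post_padded = 4096          # e.g. sorted_token_ids.shape[0]
    N, BLOCK_SIZE_M, BLOCK_SIZE_N = 3072, 32, 128
    grid = (ceil_div(max_possible_num_post_padded, BLOCK_SIZE_M)
            * ceil_div(N, BLOCK_SIZE_N),)        # 128 * 24 = 3072 programs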
stride_am=x.strides[0], stride_ak=x.strides[1], - stride_be=layer.moe_ffn1_weight.strides[0], - stride_bk=layer.moe_ffn1_weight.strides[1], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[1], stride_bn=1, stride_cm=intermediate_cache1.strides[-2], stride_cn=1, - stride_bse=layer.moe_ffn1_weight_scale.strides[0], - stride_bsk=layer.moe_ffn1_weight_scale.strides[1], + stride_bse=layer.up_gate_proj_weight_scale.strides[0], + stride_bsk=layer.up_gate_proj_weight_scale.strides[1], stride_bsn=1, - stride_bce=layer.moe_ffn1_code_scale.strides[0], + stride_bce=layer.up_gate_proj_code_scale.strides[0], stride_bck=1, stride_bcn=1, BLOCK_SIZE_M=config["BLOCK_SIZE_M"], @@ -361,17 +361,17 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): } grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) * - ceil_div(layer.moe_ffn2_weight.shape[-1], config["BLOCK_SIZE_N"]), ) + ceil_div(layer.down_proj_weight.shape[-1], config["BLOCK_SIZE_N"]), ) moe_wint2_ffn_kernel[grid]( intermediate_cache2, - layer.moe_ffn2_weight, + layer.down_proj_weight, intermediate_cache3, - layer.moe_ffn2_weight_scale, - layer.moe_ffn2_super_scales, - layer.moe_ffn2_code_scale, - layer.moe_ffn2_code_zp, + layer.down_proj_weight_scale, + layer.down_proj_super_scales, + layer.down_proj_code_scale, + layer.down_proj_code_zp, topk_weights, sorted_token_ids, expert_ids, @@ -379,7 +379,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): num_valid_tokens, max_possible_num_post_padded, # Matrix dimensions - N=layer.moe_ffn2_weight.shape[-1], + N=layer.down_proj_weight.shape[-1], K=intermediate_cache2.shape[-1], # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is @@ -387,15 +387,15 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): # (A has M rows). stride_am=intermediate_cache2.strides[0], stride_ak=1, - stride_be=layer.moe_ffn2_weight.strides[0], - stride_bk=layer.moe_ffn2_weight.strides[1], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[1], stride_bn=1, stride_cm=intermediate_cache3.strides[-2], stride_cn=1, - stride_bse=layer.moe_ffn2_weight_scale.strides[0], - stride_bsk=layer.moe_ffn2_weight_scale.strides[1], + stride_bse=layer.down_proj_weight_scale.strides[0], + stride_bsk=layer.down_proj_weight_scale.strides[1], stride_bsn=1, - stride_bce=layer.moe_ffn2_code_scale.strides[0], + stride_bce=layer.down_proj_code_scale.strides[0], stride_bck=1, stride_bcn=1, BLOCK_SIZE_M=config["BLOCK_SIZE_M"], diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py index 3bd99ce17..6f74acdff 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py @@ -38,14 +38,14 @@ class XPUMoEMethod(MoEMethodBase): Paddle cutlass create weight process. 
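The XPU create_weights below transposes each per-expert weight before stacking, so the stacked parameter ends up as [num_local_experts, out_features, in_features]. A shape-only illustration with toy sizes:

    import paddle

    num_local_experts, hidden_size, inter_size = 2, 8, 16
    up_gate_proj_weights = [paddle.randn([hidden_size, inter_size * 2])
                            for _ in range(num_local_experts)]
    stacked = paddle.stack([w.transpose([1, 0]) for w in up_gate_proj_weights], axis=0)
    # stacked.shape == [2, 32, 8]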
""" # bf16 - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) - for weights in [ffn1_weights, ffn2_weights]: + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + for weights in [up_gate_proj_weights, down_proj_weights]: for idx, weight in enumerate(weights): weights[idx] = weight.transpose([1, 0]) - stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0) - stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0) + stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0) + stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0) for idx, weight_tensor in enumerate( - [stacked_ffn1_weights, stacked_ffn2_weights]): + [stacked_up_gate_proj_weights, stacked_down_proj_weights]): weight_name = self.added_weight_attrs[idx] setattr( layer, weight_name, @@ -71,13 +71,13 @@ class XPUMoEMethod(MoEMethodBase): x, layer.gate_weight.transpose([1, 0]), layer.gate_correction_bias, - layer.moe_ffn1_weight, - layer.moe_ffn2_weight, - None, # ffn1 bias - None, # ffn2 bias - None, # ffn1 scale - None, # ffn2 scale - None, # ffn1_in_scale + layer.up_gate_proj_weight, + layer.down_proj_weight, + None, # up_gate_proj bias + None, # down_proj bias + None, # up_gate_proj scale + None, # down_proj scale + None, # up_gate_proj_in_scale "", # moe_quant_type layer.top_k, False, # moe group, used in deepseek @@ -129,20 +129,20 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase): """ Paddle cutlass create weight process. """ - ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) - assert len(ffn1_weights) == layer.num_local_experts - assert len(ffn2_weights) == layer.num_local_experts - assert ffn1_weights[0].shape == [ + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + assert len(up_gate_proj_weights) == layer.num_local_experts + assert len(down_proj_weights) == layer.num_local_experts + assert up_gate_proj_weights[0].shape == [ layer.hidden_size, layer.moe_intermediate_size * 2 ] - assert ffn2_weights[0].shape == [ + assert down_proj_weights[0].shape == [ layer.moe_intermediate_size, layer.hidden_size ] - added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"] - added_scale_attrs = ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"] + added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] + added_scale_attrs = ["up_gate_proj_weight_scale", "down_proj_weight_scale"] - for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]): + for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): weight_name = added_weight_attrs[idx] scale_name = added_scale_attrs[idx] @@ -189,16 +189,16 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase): x, layer.gate_weight.transpose([1, 0]), layer.gate_correction_bias, - layer.moe_ffn1_weight, - layer.moe_ffn2_weight, - None, # ffn1 bias - None, # ffn2 bias - (layer.moe_ffn1_weight_scale - if hasattr(layer, "moe_ffn1_weight_scale") else None), - (layer.moe_ffn2_weight_scale - if hasattr(layer, "moe_ffn2_weight_scale") else None), - (layer.moe_ffn2_in_scale - if hasattr(layer, "moe_ffn2_in_scale") else None), + layer.up_gate_proj_weight, + layer.down_proj_weight, + None, # up_gate_proj bias + None, # down_proj bias + (layer.up_gate_proj_weight_scale + if hasattr(layer, "up_gate_proj_weight_scale") else None), + (layer.down_proj_weight_scale + if hasattr(layer, "down_proj_weight_scale") else None), + (layer.down_proj_in_scale + if hasattr(layer, "down_proj_in_scale") else None), self.moe_quant_type, layer.top_k, False, 
# moe group, used in deepseek diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 2a6a8b4a1..5b9b72656 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -145,13 +145,13 @@ class FusedMoE(nn.Layer): shape=gate_correction_bias_shape, dtype="float32", ) - ffn1_output_dim = self.moe_intermediate_size * 2 + up_gate_proj_output_dim = self.moe_intermediate_size * 2 if self.moe_quant_type in ["fp8", "wint8"]: - ffn1_weight_shape = [self.num_local_experts, ffn1_output_dim, self.hidden_size] - ffn2_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size] + up_gate_proj_weight_shape = [self.num_local_experts, up_gate_proj_output_dim, self.hidden_size] + down_proj_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size] else: - ffn1_weight_shape = [self.num_local_experts, self.hidden_size, ffn1_output_dim] - ffn2_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size] + up_gate_proj_weight_shape = [self.num_local_experts, self.hidden_size, up_gate_proj_output_dim] + down_proj_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size] # Create parameters if self.moe_quant_type == "fp8": @@ -161,15 +161,15 @@ class FusedMoE(nn.Layer): self.weight_dtype = "int8" self.init_weight_only_scale() - # FFN1 parameters - self.moe_ffn1_weight = self.create_parameter( - shape=ffn1_weight_shape, + # up_gate_proj parameters + self.up_gate_proj_weight = self.create_parameter( + shape=up_gate_proj_weight_shape, dtype=self.weight_dtype, default_initializer=paddle.nn.initializer.Constant(0), ) - # FFN2 parameters - self.moe_ffn2_weight = self.create_parameter( - shape=ffn2_weight_shape, + # down_proj parameters + self.down_proj_weight = self.create_parameter( + shape=down_proj_weight_shape, dtype=self.weight_dtype, default_initializer=paddle.nn.initializer.Constant(0), ) @@ -178,44 +178,44 @@ class FusedMoE(nn.Layer): """ Initialize the weight scale. """ - self.moe_ffn1_weight_scale = self.create_parameter( + self.up_gate_proj_weight_scale = self.create_parameter( shape=[self.num_local_experts, self.moe_intermediate_size * 2], dtype=self._dtype, ) - self.moe_ffn2_weight_scale = self.create_parameter( + self.down_proj_weight_scale = self.create_parameter( shape=[self.num_local_experts, self.hidden_size], dtype=self._dtype, ) def load_experts_weight(self, state_dict: dict, - ffn1_expert_weight_key: str, - ffn2_expert_weight_key: str): + up_gate_proj_expert_weight_key: str, + down_proj_expert_weight_key: str): """ Load experts weight from state_dict. Args: state_dict (dict): The state_dict of model. - ffn1_expert_weight_key (str): The key of ffn1 expert weight. - ffn2_expert_weight_key (str): The key of ffn2 expert weight. + up_gate_proj_expert_weight_key (str): The key of up_gate_proj expert weight. + down_proj_expert_weight_key (str): The key of down_proj expert weight. 
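The body of this method, just below, has to handle both checkpoint layouts: if the fused up_gate_proj tensor is present it is loaded directly, otherwise the gate_proj and up_proj tensors are loaded separately (their keys derived by substituting into the up_gate_proj key) and concatenated along the last axis. A condensed sketch of that branch, with illustrative key names:

    import paddle

    def load_one_up_gate_proj(state_dict, up_gate_proj_key, gate_key, up_key, expert_idx):
        # Sketch only; the real method also routes tensors through get_tensor().
        merged_key = up_gate_proj_key.format(expert_idx)
        if merged_key in state_dict:
            return state_dict.pop(merged_key)
        gate = state_dict.pop(gate_key.format(expert_idx))
        up = state_dict.pop(up_key.format(expert_idx))
        return paddle.concat([gate, up], axis=-1)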
""" - ffn1_weights = [] - ffn2_weights = [] - is_ffn_merged = ffn1_expert_weight_key.format( + up_gate_proj_weights = [] + down_proj_weights = [] + is_ffn_merged = up_gate_proj_expert_weight_key.format( self.expert_id_offset) in state_dict if is_ffn_merged: for i in range(self.num_local_experts): expert_idx = self.expert_id_offset + i - ffn1_weights.append( + up_gate_proj_weights.append( get_tensor( state_dict.pop( - ffn1_expert_weight_key.format(expert_idx)))) - ffn2_weights.append( + up_gate_proj_expert_weight_key.format(expert_idx)))) + down_proj_weights.append( get_tensor( state_dict.pop( - ffn2_expert_weight_key.format(expert_idx)))) + down_proj_expert_weight_key.format(expert_idx)))) else: - gate_expert_weight_key = ffn1_expert_weight_key.replace( + gate_expert_weight_key = up_gate_proj_expert_weight_key.replace( "up_gate_proj", "gate_proj") - up_expert_weight_key = ffn1_expert_weight_key.replace( + up_expert_weight_key = up_gate_proj_expert_weight_key.replace( "up_gate_proj", "up_proj") for j in range(self.num_local_experts): expert_idx = self.expert_id_offset + j @@ -223,12 +223,12 @@ class FusedMoE(nn.Layer): state_dict.pop(gate_expert_weight_key.format(expert_idx))) up = get_tensor( state_dict.pop(up_expert_weight_key.format(expert_idx))) - ffn1_weights.append(paddle.concat([gate, up], axis=-1)) - ffn2_weights.append( + up_gate_proj_weights.append(paddle.concat([gate, up], axis=-1)) + down_proj_weights.append( get_tensor( state_dict.pop( - ffn2_expert_weight_key.format(expert_idx)))) - return ffn1_weights, ffn2_weights + down_proj_expert_weight_key.format(expert_idx)))) + return up_gate_proj_weights, down_proj_weights def extract_moe_ffn_weights(self, state_dict: dict): """ @@ -239,30 +239,30 @@ class FusedMoE(nn.Layer): Returns: tuple: A tuple containing two lists: - - ffn1_weights: List of tensors for first FFN layer weights - - ffn2_weights: List of tensors for second FFN layer weights + - up_gate_proj_weights: List of tensors for first FFN layer weights + - down_proj_weights: List of tensors for second FFN layer weights Raises: AssertionError: If required weight keys are missing or number of weights doesn't match number of local experts. """ - ffn1_expert_weight_key = self.weight_key_map.get( - "ffn1_expert_weight_key", None) - ffn2_expert_weight_key = self.weight_key_map.get( - "ffn2_expert_weight_key", None) - assert ffn1_expert_weight_key is not None, "ffn1_expert_weight_key should not be none." - assert ffn2_expert_weight_key is not None, "ffn2_expert_weight_key should not be none." + up_gate_proj_expert_weight_key = self.weight_key_map.get( + "up_gate_proj_expert_weight_key", None) + down_proj_expert_weight_key = self.weight_key_map.get( + "down_proj_expert_weight_key", None) + assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none." + assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none." - ffn1_weights, ffn2_weights = self.load_experts_weight( - state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key) + up_gate_proj_weights, down_proj_weights = self.load_experts_weight( + state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key) assert len( - ffn1_weights - ) == self.num_local_experts, "ffn1_weights length should be equal to num_local_experts." + up_gate_proj_weights + ) == self.num_local_experts, "up_gate_proj_weights length should be equal to num_local_experts." 
assert len( - ffn2_weights - ) == self.num_local_experts, "ffn2_weights length should be equal to num_local_experts." + down_proj_weights + ) == self.num_local_experts, "down_proj_weights length should be equal to num_local_experts." - return ffn1_weights, ffn2_weights + return up_gate_proj_weights, down_proj_weights def extract_gate_correction_bias(self, gate_correction_bias_key, state_dict): diff --git a/fastdeploy/model_executor/layers/mtp_linear.py b/fastdeploy/model_executor/layers/mtp_linear.py index 5506d2faa..80a8835ea 100644 --- a/fastdeploy/model_executor/layers/mtp_linear.py +++ b/fastdeploy/model_executor/layers/mtp_linear.py @@ -46,11 +46,11 @@ class ParallelEHProjection(nn.Layer): prefix (str): full name of the layer in the state dict """ super(ParallelEHProjection, self).__init__() - self.linear_weight_key = prefix + ".weight" + self.weight_key = prefix + ".weight" if with_bias: - self.linear_bias_key = prefix + ".bias" + self.bias_key = prefix + ".bias" else: - self.linear_bias_key = None + self.bias_key = None self.use_ep = fd_config.parallel_config.use_ep self.column_cut = True @@ -66,26 +66,26 @@ class ParallelEHProjection(nn.Layer): else: if self.column_cut: need_gather = True - self.out_linear = ColumnParallelLinear( + self.linear = ColumnParallelLinear( embedding_dim, num_embeddings, mp_group=fleet.get_hybrid_communicate_group(). get_model_parallel_group(), weight_attr=None, has_bias=True - if self.linear_bias_key is not None else False, + if self.bias_key is not None else False, gather_output=need_gather, fuse_matmul_bias=False, # False diff更小 ) else: - self.out_linear = RowParallelLinear( + self.linear = RowParallelLinear( embedding_dim, num_embeddings, mp_group=fleet.get_hybrid_communicate_group(). get_model_parallel_group(), weight_attr=None, has_bias=True - if self.linear_bias_key is not None else False, + if self.bias_key is not None else False, input_is_parallel=False, fuse_matmul_bias=False, # False diff更小 ) @@ -100,20 +100,20 @@ class ParallelEHProjection(nn.Layer): if self.use_ep: self.weight.set_value( - get_tensor(state_dict.pop(self.linear_weight_key)).astype( + get_tensor(state_dict.pop(self.weight_key)).astype( paddle.get_default_dtype())) else: weight_tensor = get_tensor( - state_dict.pop(self.linear_weight_key)).astype( + state_dict.pop(self.weight_key)).astype( paddle.get_default_dtype()) - if self.out_linear.weight.shape != weight_tensor.shape: + if self.linear.weight.shape != weight_tensor.shape: weight_tensor = weight_tensor.transpose([1, 0]) - self.out_linear.weight.set_value(weight_tensor) + self.linear.weight.set_value(weight_tensor) - if self.linear_bias_key is not None: - bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype( + if self.bias_key is not None: + bias = get_tensor(state_dict.pop(self.bias_key)).astype( paddle.get_default_dtype()) - self.out_linear.bias.set_value(bias) + self.linear.bias.set_value(bias) def forward(self, input): """ @@ -129,5 +129,5 @@ class ParallelEHProjection(nn.Layer): if self.use_ep: logits = paddle.matmul(logits, self.weight) else: - logits = self.out_linear(logits) + logits = self.linear(logits) return logits diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 9b16830b6..c91e74173 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -43,7 +43,7 @@ class RMSNorm(nn.Layer): hidden_size: int, eps: float = 1e-5, prefix: str = "", - linear_bias: paddle.Tensor = None, 
+ bias: paddle.Tensor = None, quant_scale: float = None, begin_norm_axis: int = 1, ) -> None: @@ -57,7 +57,7 @@ class RMSNorm(nn.Layer): hidden_size (int) : size of hidden state. eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5. prefix(str,optional):The name of current layer. Defaults to "". - linear_bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None. + bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None. quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization. begin_norm_axis (int, optional): The axis along which to perform normalization. Defaults to 1. @@ -78,7 +78,7 @@ class RMSNorm(nn.Layer): self.norm_func: Callable = fused_add_rms_norm else: self.norm_func: Callable = fused_rms_norm - self.linear_bias: Optional[paddle.Tensor] = linear_bias + self.bias: Optional[paddle.Tensor] = bias self.quant_scale: Optional[float] = quant_scale self._dtype: str = self._helper.get_default_dtype() self._norm_weight_dtype: str = self._dtype @@ -94,9 +94,9 @@ class RMSNorm(nn.Layer): Initialize the weights and biases. """ - self.ln_weight = None + self.weight = None if self.with_weight: - self.ln_weight = self.create_parameter( + self.weight = self.create_parameter( shape=[self.hidden_size], default_initializer=nn.initializer.Constant(value=1.0), dtype=self._norm_weight_dtype, @@ -115,7 +115,7 @@ class RMSNorm(nn.Layer): weight_tensor = paddle.cast( get_tensor(state_dict.pop(self.weight_key)), self._norm_weight_dtype) - self.ln_weight.set_value(weight_tensor) + self.weight.set_value(weight_tensor) def forward( self, @@ -139,18 +139,18 @@ class RMSNorm(nn.Layer): """ if current_platform.is_gcu(): if residual_input is None: - return rms_norm(x, self.ln_weight, self.eps) + return rms_norm(x, self.weight, self.eps) norm_out = self.norm_func( - x, residual_input, self.ln_weight, self.eps + x, residual_input, self.weight, self.eps ) else: norm_out = self.norm_func( x, - norm_weight=self.ln_weight, + norm_weight=self.weight, norm_bias=None, epsilon=self.eps, begin_norm_axis=self.begin_norm_axis, - bias=self.linear_bias, + bias=self.bias, residual=residual_input, quant_scale=-1 if self.quant_scale is None else self.quant_scale, quant_round_type=self.quant_round_type, @@ -174,7 +174,7 @@ class LayerNorm(nn.Layer): hidden_size: int, eps: float = 1e-5, prefix="", - linear_bias: paddle.Tensor = None, + bias: paddle.Tensor = None, quant_scale: float = None, with_bias: bool = False, ): @@ -189,7 +189,7 @@ class LayerNorm(nn.Layer): eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5. prefix (str): Unique name of the layer, used for naming internal attributes, you can give it any name you like. - linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None. + bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None. quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization. with_bias (bool):Whether to include bias or not. Defaults to False. 
Raises: @@ -212,7 +212,7 @@ class LayerNorm(nn.Layer): self.norm_func: Callable = paddle.nn.functional.layer_norm else: self.norm_func: Callable = fused_layer_norm - self.linear_bias: Optional[paddle.Tensor] = linear_bias + self.bias: Optional[paddle.Tensor] = bias self._dtype: str = self._helper.get_default_dtype() self._norm_weight_dtype: str = "float32" @@ -227,16 +227,16 @@ class LayerNorm(nn.Layer): Initialize the weights and biases. """ - self.ln_weight = None + self.weight = None if self.with_weight: - self.ln_weight = self.create_parameter( + self.weight = self.create_parameter( shape=[self.hidden_size], default_initializer=nn.initializer.Constant(value=1.0), dtype=self._norm_weight_dtype, ) - self.ln_bias = None + self.bias = None if self.with_bias: - self.ln_bias = self.create_parameter( + self.bias = self.create_parameter( shape=[self.hidden_size], is_bias=True, dtype=self._norm_weight_dtype, @@ -255,14 +255,14 @@ class LayerNorm(nn.Layer): weight_tensor = paddle.cast( get_tensor(state_dict.pop(self.weight_key)), self._norm_weight_dtype) - self.ln_weight.set_value(weight_tensor) + self.weight.set_value(weight_tensor) # bias if self.with_bias: bias_tensor = paddle.cast( get_tensor(state_dict.pop(self.bias_key)), self._norm_weight_dtype) - self.ln_bias.set_value(bias_tensor) + self.bias.set_value(bias_tensor) def forward( self, @@ -285,10 +285,10 @@ class LayerNorm(nn.Layer): operations (like linear transformation) on the `residual_input`. """ if current_platform.is_iluvatar(): - if self.ln_weight is None and self.ln_bias is None: + if self.weight is None and self.bias is None: out = x - if self.linear_bias is not None: - out += self.linear_bias + if self.bias is not None: + out += self.bias if residual_input is not None: out += residual_input return out, out @@ -303,8 +303,8 @@ class LayerNorm(nn.Layer): out = self.norm_func( x=y, normalized_shape=y.shape[1:], - weight=self.ln_weight, - bias=self.linear_bias, + weight=self.weight, + bias=self.bias, epsilon=self.eps, ) return out, y @@ -312,19 +312,19 @@ class LayerNorm(nn.Layer): out = self.norm_func( x=x, normalized_shape=x.shape[1:], - weight=self.ln_weight, - bias=self.linear_bias, + weight=self.weight, + bias=self.bias, epsilon=self.eps, ) return out else: norm_out = self.norm_func( x, - norm_weight=self.ln_weight, - norm_bias=self.ln_bias, + norm_weight=self.weight, + norm_bias=self.bias, epsilon=self.eps, begin_norm_axis=1, - bias=self.linear_bias, + bias=self.bias, residual=residual_input, quant_scale=-1 if self.quant_scale is None else self.quant_scale, quant_round_type=self.quant_round_type, diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index af061ce83..43f3bbc23 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -78,8 +78,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase): self.quant_config = quant_config def create_weights(self, layer): - layer.linear_weight_shape.reverse() - layer.linear_weight_scale = layer.create_parameter( + layer.weight_shape.reverse() + layer.weight_scale = layer.create_parameter( shape=[ (layer.output_size + self.quant_config.weight_block_size[0] - 1) // self.quant_config.weight_block_size[0], @@ -95,8 +95,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase): weight_tensor = weights.transpose([1, 0]) quanted_weight_tensor, weight_block_scale_tensor = ( per_block_cast_to_fp8(weight_tensor)) - 
layer.linear_weight.copy_(quanted_weight_tensor, False) - layer.linear_weight_scale.set_value(weight_block_scale_tensor) + layer.weight.copy_(quanted_weight_tensor, False) + layer.weight_scale.set_value(weight_block_scale_tensor) def process_prequanted_weights(self, layer, state_dict): """ @@ -106,10 +106,10 @@ class BlockWiseFP8LinearMethod(QuantMethodBase): weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key)) quant_weight = quant_weight.transpose([1, 0]).contiguous() - layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False) + layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False) weight_scale = weight_scale.transpose([1, 0]) - layer.linear_weight_scale.set_value(weight_scale) + layer.weight_scale.set_value(weight_scale) def apply(self, layer, x): x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding( @@ -119,9 +119,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase): import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm deep_gemm.gemm_fp8_fp8_bf16_nt( (x, x_scale_tensor), - (layer.linear_weight, layer.linear_weight_scale), + (layer.weight, layer.weight_scale), linear_out, ) if layer.with_bias: - linear_out = paddle.add(linear_out, layer.linear_bias) + linear_out = paddle.add(linear_out, layer.bias) return linear_out diff --git a/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py index 99a8562b8..e2845af36 100644 --- a/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py @@ -96,7 +96,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase): act_scale = get_tensor(state_dict.pop(layer.act_scale_key)) quant_weight = quant_weight.transpose([1, 0]).contiguous() - layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False) + layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False) self.act_scale = act_scale.item() self.total_scale = (act_scale * weight_scale).item() @@ -118,7 +118,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase): linear_out = cutlass_fp8_fp8_half_gemm_fused( fp8_x, - layer.linear_weight, + layer.weight, transpose_x=False, transpose_y=True, bias=None, diff --git a/fastdeploy/model_executor/layers/quantization/w4afp8.py b/fastdeploy/model_executor/layers/quantization/w4afp8.py index 49453c553..0785f4ab9 100644 --- a/fastdeploy/model_executor/layers/quantization/w4afp8.py +++ b/fastdeploy/model_executor/layers/quantization/w4afp8.py @@ -63,8 +63,8 @@ class W4AFP8LinearMethod(QuantMethodBase): self.quant_config = quant_config def create_weights(self, layer): - layer.linear_weight_shape.reverse() - layer.linear_weight_shape[0] //= 2 + layer.weight_shape.reverse() + layer.weight_shape[0] //= 2 layer.weight_dtype = "int8" pass @@ -77,16 +77,16 @@ class W4AFP8LinearMethod(QuantMethodBase): scale_dtype="float16", )) weight_scale_tensor = paddle.view(weight_scale_tensor, layer._dtype) - layer.linear_weight.set_value(quanted_weight_tensor) - layer.linear_weight_scale.set_value(weight_scale_tensor) + layer.weight.set_value(quanted_weight_tensor) + layer.weight_scale.set_value(weight_scale_tensor) def apply(self, layer, x): linear_out = fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16( x, - layer.linear_weight, - layer.linear_weight_scale, + layer.weight, + layer.weight_scale, zero_points=None, - bias=layer.linear_bias if layer.add_bias else None, + bias=layer.bias if layer.add_bias else None, 
out_scale=self.quant_config.weight_scale_dict.get(layer.prefix + ".weight_scale") / (self.quant_config.act_scale_dict.get(layer.prefix + diff --git a/fastdeploy/model_executor/layers/quantization/w8a8.py b/fastdeploy/model_executor/layers/quantization/w8a8.py index 845421018..0d86789e0 100644 --- a/fastdeploy/model_executor/layers/quantization/w8a8.py +++ b/fastdeploy/model_executor/layers/quantization/w8a8.py @@ -69,7 +69,7 @@ class W8A8LinearMethod(QuantMethodBase): self.smooth_quant_method = SmoothQuantLinearMethod(quant_config) def create_weights(self, layer): - layer.linear_weight_shape.reverse() + layer.weight_shape.reverse() layer.weight_dtype = "int8" if self.quant_config.use_smooth_quant: self.smooth_quant_method.create_weights(layer) @@ -101,21 +101,21 @@ class W8A8LinearMethod(QuantMethodBase): if self.skip_quant: logger.debug(f"{layer.prefix} skip quant") weight_tensor = weights.cast(layer._dtype) - layer.linear_weight.set_value(weight_tensor) + layer.weight.set_value(weight_tensor) else: weight_tensor = weights.transpose([1, 0]) weight_tensor = paddle.cast(weight_tensor, "int8") - layer.linear_weight.set_value(weight_tensor) + layer.weight.set_value(weight_tensor) def apply(self, layer, x): if self.skip_quant: - linear_out = paddle.matmul(x, layer.linear_weight, False, True) + linear_out = paddle.matmul(x, layer.weight, False, True) return linear_out if self.quant_config.use_gemm_dequant: linear_out = fastdeploy.model_executor.ops.gpu.gemm_dequant( - x, layer.linear_weight, layer.linear_out_scale, layer._dtype) + x, layer.weight, layer.linear_out_scale, layer._dtype) else: - linear_out = paddle.matmul(x, layer.linear_weight, False, True) + linear_out = paddle.matmul(x, layer.weight, False, True) linear_out = fastdeploy.model_executor.ops.gpu.dequant_int8( linear_out, layer.linear_out_scale, layer._dtype) return linear_out diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index c87ba7edd..0a48c60f3 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -77,12 +77,12 @@ class WeightOnlyConfig(QuantConfigBase): return GCUWeightOnlyLinearMethod(self) elif current_platform.is_dcu(): if isinstance(layer, FusedMoE): - from fastdeploy.model_executor.layers.backends import ( - DCUTritonWeightOnlyMoEMethod) + from fastdeploy.model_executor.layers.backends import \ + DCUTritonWeightOnlyMoEMethod return DCUTritonWeightOnlyMoEMethod(self) else: - from fastdeploy.model_executor.layers.backends import ( - DCUWeightOnlyLinearMethod) + from fastdeploy.model_executor.layers.backends import \ + DCUWeightOnlyLinearMethod return DCUWeightOnlyLinearMethod(self) else: if isinstance(layer, FusedMoE): @@ -152,14 +152,14 @@ class WeightOnlyLinearMethod(QuantMethodBase): def create_weights(self, layer): # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. 
- linear_weight_scale_shape = [layer.linear_weight_shape[1]] + weight_scale_shape = [layer.weight_shape[1]] - layer.linear_weight_shape.reverse() + layer.weight_shape.reverse() if self.quant_config.name() == "wint4": - layer.linear_weight_shape[0] //= 2 + layer.weight_shape[0] //= 2 layer.weight_dtype = "int8" - layer.linear_weight_scale = layer.create_parameter( - shape=linear_weight_scale_shape, + layer.weight_scale = layer.create_parameter( + shape=weight_scale_shape, dtype=layer._dtype, is_bias=False, ) @@ -171,9 +171,9 @@ class WeightOnlyLinearMethod(QuantMethodBase): def apply(self, layer, x): linear_out = weight_only_linear( x, - weight=layer.linear_weight, - bias=layer.linear_bias if layer.add_bias else None, - weight_scale=layer.linear_weight_scale, + weight=layer.weight, + bias=layer.bias if layer.add_bias else None, + weight_scale=layer.weight_scale, weight_dtype="int8" if self.quant_config.name() == "wint8" else "int4", arch=self.quant_config.weight_only_linear_arch, @@ -204,8 +204,8 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod): """ quant_weight = get_tensor(state_dict.pop(layer.weight_key)) weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key)) - layer.linear_weight.set_value(quant_weight) - layer.linear_weight_scale.set_value( + layer.weight.set_value(quant_weight) + layer.weight_scale.set_value( weight_scale.astype(paddle.get_default_dtype())) def process_loaded_weights(self, layer, weight) -> None: @@ -216,6 +216,6 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod): arch=self.quant_config.weight_only_linear_arch, ) - layer.linear_weight.set_value(quanted_weight_tensor) - layer.linear_weight_scale.set_value( + layer.weight.set_value(quanted_weight_tensor) + layer.weight_scale.set_value( weight_scale_tensor.astype(paddle.get_default_dtype())) diff --git a/fastdeploy/model_executor/layers/quantization/wfp8afp8.py b/fastdeploy/model_executor/layers/quantization/wfp8afp8.py index 4351ed138..34e2b7845 100644 --- a/fastdeploy/model_executor/layers/quantization/wfp8afp8.py +++ b/fastdeploy/model_executor/layers/quantization/wfp8afp8.py @@ -70,11 +70,11 @@ class WFP8AFP8LinearMethod(QuantMethodBase): def create_weights(self, layer): """ """ - layer.linear_weight_shape.reverse() + layer.weight_shape.reverse() layer.weight_dtype = "float8_e4m3fn" # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func self.skip_quant = False - layer.linear_weight_scale = layer.create_parameter( + layer.weight_scale = layer.create_parameter( shape=[1], dtype="float32", is_bias=False, @@ -86,7 +86,7 @@ class WFP8AFP8LinearMethod(QuantMethodBase): """ if self.skip_quant: weight_tensor = weights.cast(layer._dtype) - layer.linear_weight.set_value(weight_tensor) + layer.weight.set_value(weight_tensor) return if weights.dtype != paddle.float8_e4m3fn: self.use_per_token_if_dynamic = True @@ -95,22 +95,22 @@ class WFP8AFP8LinearMethod(QuantMethodBase): weight_tensor, use_per_token_if_dynamic=False, ) - layer.linear_weight.copy_(qweight, False) - layer.linear_weight_scale.set_value(weight_scale) + layer.weight.copy_(qweight, False) + layer.weight_scale.set_value(weight_scale) def apply(self, layer, x): """ """ if self.skip_quant: - linear_out = paddle.matmul(x, layer.linear_weight, False, True) + linear_out = paddle.matmul(x, layer.weight, False, True) return linear_out if self.use_per_token_if_dynamic: out_type = x.dtype a_q, a_scales = scaled_fp8_quant( x, use_per_token_if_dynamic=self.use_per_token_if_dynamic) - linear_out = cutlass_scaled_mm(a_q, 
layer.linear_weight, a_scales, - layer.linear_weight_scale, out_type, - layer.linear_bias) + linear_out = cutlass_scaled_mm(a_q, layer.weight, a_scales, + layer.weight_scale, out_type, + layer.bias) else: raise NotImplementedError return linear_out diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index c9fcbd086..211bd93af 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -48,22 +48,22 @@ def load_ep_checkpoint(model_path: str, config.num_experts_start_offset, config.num_experts_start_offset + config.num_experts_per_rank, ): - ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight" - ffn2_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight") + up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight" + down_proj_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight") - ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight" - ffn2_quant_key = ( + up_gate_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight" + down_proj_quant_key = ( f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight") - ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale" - ffn2_scale_key = ( + up_gate_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale" + down_proj_scale_key = ( f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale") - num_local_ffn_keys.append(ffn1_key) - num_local_ffn_keys.append(ffn2_key) - num_local_ffn_keys.append(ffn1_quant_key) - num_local_ffn_keys.append(ffn2_quant_key) - num_local_ffn_keys.append(ffn1_scale_key) - num_local_ffn_keys.append(ffn2_scale_key) + num_local_ffn_keys.append(up_gate_proj_key) + num_local_ffn_keys.append(down_proj_key) + num_local_ffn_keys.append(up_gate_proj_quant_key) + num_local_ffn_keys.append(down_proj_quant_key) + num_local_ffn_keys.append(up_gate_proj_scale_key) + num_local_ffn_keys.append(down_proj_scale_key) for k in num_local_ffn_keys: if k in weight_list: diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index a9ac1d50b..c7f573772 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -61,7 +61,7 @@ class DeepSeekV3MLP(nn.Layer): ) -> None: super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( + self.up_gate_proj = MergedColumnParallelLinear( fd_config=fd_config, prefix=f"{prefix}.up_gate_proj", input_size=fd_config.model_config.hidden_size, @@ -88,13 +88,13 @@ class DeepSeekV3MLP(nn.Layer): def load_state_dict(self, state_dict): """ """ - self.gate_up_proj.load_state_dict(state_dict) + self.up_gate_proj.load_state_dict(state_dict) self.down_proj.load_state_dict(state_dict) def forward(self, x): """ """ - gate_up_out = self.gate_up_proj(x) + gate_up_out = self.up_gate_proj(x) act_out = self.act_fn(gate_up_out) down_out = self.down_proj(act_out) return down_out @@ -115,9 +115,9 @@ class DeepSeekV3MoE(nn.Layer): "gate_weight_key": f"{prefix}.gate.weight", "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias", - "ffn1_expert_weight_key": + "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", - "ffn2_expert_weight_key": + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", } @@ -528,7 +528,7 @@ class DeepSeekV3Model(nn.Layer): self.num_layers = fd_config.model_config.num_hidden_layers 
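# --- Quick reference for the renames applied throughout this diff (a plain
# summary dictionary, not code from FastDeploy):
RENAMED_IDENTIFIERS = {
    "ffn1_*": "up_gate_proj_*",              # first expert FFN matmul (fused gate + up)
    "ffn2_*": "down_proj_*",                 # second expert FFN matmul
    "gate_up_proj": "up_gate_proj",          # dense-MLP attribute in the model classes
    "linear_weight / linear_bias": "weight / bias",  # linear-layer parameters
    "ln_weight / ln_bias": "weight / bias",  # RMSNorm / LayerNorm parameters
    "out_linear": "linear",                  # lm_head / ParallelEHProjection projection
    "embeddings": "embed_tokens",            # token-embedding attribute
    "hidden_layers": "layers",               # decoder-layer list attribute
    "self.model": "self.ernie / self.qwen2", # top-level submodule on *ForCausalLM
}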
fd_config.model_config.pretrained_config.prefix_name = "deepseek_v3" - self.embeddings = VocabParallelEmbedding( + self.embed_tokens = VocabParallelEmbedding( fd_config, num_embeddings=fd_config.model_config.vocab_size, embedding_dim=fd_config.model_config.hidden_size, @@ -554,7 +554,7 @@ class DeepSeekV3Model(nn.Layer): """ Load model parameters from a given state dictionary. """ - self.embeddings.load_state_dict(state_dict) + self.embed_tokens.load_state_dict(state_dict) self.norm.load_state_dict(state_dict) for i in range(self.num_layers): logger.info(f"Start load layer {i}") @@ -569,7 +569,7 @@ class DeepSeekV3Model(nn.Layer): ): """ """ - hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding) + hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding) residual = None for i in range(self.num_layers): diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 4ae0b3c18..3c8e0d8e5 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -23,6 +23,7 @@ import numpy as np import paddle from paddle import nn from paddleformers.transformers import PretrainedModel +from paddleformers.transformers.configuration_utils import PretrainedConfig from paddleformers.utils.log import logger from fastdeploy.config import FDConfig @@ -55,7 +56,7 @@ class Ernie4_5_MLP(nn.Layer): ) -> None: super().__init__() self.nranks = fd_config.parallel_config.tensor_parallel_size - self.gate_up_proj = MergedColumnParallelLinear( + self.up_gate_proj = MergedColumnParallelLinear( fd_config=fd_config, prefix=f"{prefix}.up_gate_proj", input_size=fd_config.model_config.hidden_size, @@ -79,11 +80,11 @@ class Ernie4_5_MLP(nn.Layer): ) def load_state_dict(self, state_dict): - self.gate_up_proj.load_state_dict(state_dict) + self.up_gate_proj.load_state_dict(state_dict) self.down_proj.load_state_dict(state_dict) def forward(self, hidden_states: paddle.Tensor): - gate_up_out = self.gate_up_proj(hidden_states) + gate_up_out = self.up_gate_proj(hidden_states) act_out = self.act_fn(gate_up_out) down_out = self.down_proj(act_out) return down_out @@ -104,17 +105,17 @@ class Ernie4_5_MoE(nn.Layer): f"{prefix}.gate.weight", "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias", - "ffn1_expert_weight_key": + "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight", - "ffn2_expert_weight_key": + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight", - "ffn1_expert_weight_scale_key": + "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale", - "ffn2_expert_weight_scale_key": + "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale", - "ffn1_expert_in_scale_key": + "up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale", - "ffn2_expert_in_scale_key": + "down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale", } elif moe_quant_type == "w4w2": @@ -123,25 +124,25 @@ class Ernie4_5_MoE(nn.Layer): f"{prefix}.gate.weight", "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias", - "ffn1_expert_weight_key": + "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight", - "ffn2_expert_weight_key": + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight", - "ffn1_expert_weight_scale_key": + "up_gate_proj_expert_weight_scale_key": 
f"{prefix}.experts.{{}}.up_gate_proj.weight_scale", - "ffn2_expert_weight_scale_key": + "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale", - "ffn1_expert_super_scales_key": + "up_gate_proj_expert_super_scales_key": f"{prefix}.experts.{{}}.up_gate_proj.super_scales", - "ffn2_expert_super_scales_key": + "down_proj_expert_super_scales_key": f"{prefix}.experts.{{}}.down_proj.super_scales", - "ffn1_expert_code_scale_key": + "up_gate_proj_expert_code_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.code_scale", - "ffn2_expert_code_scale_key": + "down_proj_expert_code_scale_key": f"{prefix}.experts.{{}}.down_proj.code_scale", - "ffn1_expert_code_zp_key": + "up_gate_proj_expert_code_zp_key": f"{prefix}.experts.{{}}.up_gate_proj.code_zp", - "ffn2_expert_code_zp_key": + "down_proj_expert_code_zp_key": f"{prefix}.experts.{{}}.down_proj.code_zp", } elif moe_quant_type == "tensor_wise_fp8" or ( @@ -152,17 +153,17 @@ class Ernie4_5_MoE(nn.Layer): f"{prefix}.gate.weight", "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias", - "ffn1_expert_weight_key": + "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight", - "ffn2_expert_weight_key": + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight", - "ffn1_expert_weight_scale_key": + "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale", - "ffn2_expert_weight_scale_key": + "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale", - "ffn1_expert_in_scale_key": + "up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale", - "ffn2_expert_in_scale_key": + "down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale", } else: @@ -171,9 +172,9 @@ class Ernie4_5_MoE(nn.Layer): f"{prefix}.gate.weight", "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias", - "ffn1_expert_weight_key": + "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", - "ffn2_expert_weight_key": + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", } @@ -271,7 +272,7 @@ class Ernie4_5_DecoderLayer(nn.Layer): prefix=f"{prefix}.self_attn", ) - if (fd_config.model_config.moe_num_experts is not None + if (getattr(fd_config.model_config, "moe_num_experts", None) is not None and layer_id >= fd_config.model_config.moe_layer_start_index): self.mlp = Ernie4_5_MoE( fd_config=fd_config, @@ -349,14 +350,14 @@ class Ernie4_5_Model(nn.Layer): self.num_layers = fd_config.model_config.num_hidden_layers fd_config.model_config.pretrained_config.prefix_name = "ernie" - self.embeddings = VocabParallelEmbedding( + self.embed_tokens = VocabParallelEmbedding( fd_config=fd_config, num_embeddings=fd_config.model_config.vocab_size, embedding_dim=fd_config.model_config.hidden_size, params_dtype=paddle.get_default_dtype(), prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens")) - self.hidden_layers = nn.LayerList([ + self.layers = nn.LayerList([ Ernie4_5_DecoderLayer( fd_config=fd_config, prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}") @@ -379,22 +380,22 @@ class Ernie4_5_Model(nn.Layer): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. 
""" - self.embeddings.load_state_dict(state_dict) + self.embed_tokens.load_state_dict(state_dict) self.norm.load_state_dict(state_dict) for i in range(self.num_layers): logger.info(f"Start load layer {i}") - self.hidden_layers[i].load_state_dict(state_dict) + self.layers[i].load_state_dict(state_dict) def forward( self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta, ): - hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding) + hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding) residual = None for i in range(self.num_layers): - hidden_states, residual = self.hidden_layers[i](forward_meta, + hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual) @@ -417,7 +418,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): """ super(Ernie4_5_MoeForCausalLM, self).__init__(fd_config) self.fd_config = fd_config - self.model = Ernie4_5_Model(fd_config=fd_config) + self.ernie = Ernie4_5_Model(fd_config=fd_config) self.ori_vocab_size = fd_config.model_config.ori_vocab_size @@ -444,10 +445,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. """ - self.model.load_state_dict(state_dict) + self.ernie.load_state_dict(state_dict) if self.tie_word_embeddings: - self.lm_head.out_linear.weight.set_value( - self.model.embeddings.word_embeddings.weight.transpose([1, 0])) + self.lm_head.linear.weight.set_value( + self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) else: self.lm_head.load_state_dict(state_dict) @@ -468,14 +469,14 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): ) for i in range(self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.num_hidden_layers): - self.model.hidden_layers[i].mlp.fused_moe(fake_hidden_states) + self.ernie.layers[i].mlp.fused_moe(fake_hidden_states) def forward( self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta, ): - hidden_states = self.model(ids_remove_padding=ids_remove_padding, + hidden_states = self.ernie(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) return hidden_states @@ -559,7 +560,7 @@ class Ernie4_5_PretrainedModel(PretrainedModel): ] @classmethod - def _get_tensor_parallel_mappings(cls, config, is_split=True): + def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True): """ get_tensor_parallel_mappings """ diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py index 02a711c94..47dbee48f 100644 --- a/fastdeploy/model_executor/models/ernie4_5_mtp.py +++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py @@ -263,9 +263,9 @@ class Ernie4_5_MTPModel(nn.Layer): super().__init__() self.num_layers = fd_config.model_config.num_hidden_layers - self.embeddings = fd_config.speculative_config.sharing_model.model.embeddings + self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens - self.hidden_layers = nn.LayerList([ + self.layers = nn.LayerList([ Ernie4_5_DecoderLayer( fd_config=fd_config, prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.{i}") @@ -302,13 +302,13 @@ class Ernie4_5_MTPModel(nn.Layer): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. 
""" - # self.embeddings.load_state_dict(state_dict) + # self.embed_tokens.load_state_dict(state_dict) self.enorm.load_state_dict(state_dict) self.hnorm.load_state_dict(state_dict) self.eh_proj.load_state_dict(state_dict) for i in range(self.num_layers): logger.info(f"Start load layer {i}") - self.hidden_layers[i].load_state_dict(state_dict) + self.layers[i].load_state_dict(state_dict) def forward( self, @@ -319,7 +319,7 @@ class Ernie4_5_MTPModel(nn.Layer): """ forward """ - inputs_embedding = self.embeddings( + inputs_embedding = self.embed_tokens( ids_remove_padding=ids_remove_padding) inputs_embedding = paddle.concat( [self.enorm(inputs_embedding), @@ -328,7 +328,7 @@ class Ernie4_5_MTPModel(nn.Layer): hidden_states = self.eh_proj(inputs_embedding) residual = None for i in range(self.num_layers): - hidden_states, residual = self.hidden_layers[i](forward_meta, + hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual) @@ -349,7 +349,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM): """ super(Ernie4_5_MTPForCausalLM, self).__init__(fd_config) self.fd_config = fd_config - self.model = Ernie4_5_MTPModel(fd_config=fd_config) + self.ernie = Ernie4_5_MTPModel(fd_config=fd_config) self.ori_vocab_size = fd_config.model_config.ori_vocab_size @@ -373,10 +373,10 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. """ - self.model.load_state_dict(state_dict) + self.ernie.load_state_dict(state_dict) # if self.tie_word_embeddings: - # self.lm_head.out_linear.weight.set_value( - # self.model.embeddings.word_embeddings.weight.transpose([1, 0])) + # self.lm_head.linear.weight.set_value( + # self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) # else: # self.lm_head.load_state_dict(state_dict) @@ -400,7 +400,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM): ) for i in range(self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.num_hidden_layers): - self.model.hidden_layers[i].mlp.fused_moe(fake_hidden_states) + self.ernie.layers[i].mlp.fused_moe(fake_hidden_states) def forward( self, @@ -411,7 +411,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM): """ forward """ - hidden_states = self.model(ids_remove_padding, previous_hidden_states, + hidden_states = self.ernie(ids_remove_padding, previous_hidden_states, forward_meta) return hidden_states diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 622bf2801..f592c8abb 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -24,9 +24,10 @@ import numpy as np import paddle from paddle import nn from paddleformers.transformers import PretrainedModel +from paddleformers.transformers.configuration_utils import PretrainedConfig from paddleformers.utils.log import logger -from fastdeploy.config import FDConfig, ModelConfig +from fastdeploy.config import FDConfig from fastdeploy.distributed.communication_op import \ tensor_model_parallel_all_reduce from fastdeploy.model_executor.graph_optimization.decorator import \ @@ -99,12 +100,12 @@ class Ernie4_5_VLMoE(nn.Layer): f"{prefix}.gate.weight", "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias", - "ffn1_expert_weight_key": + "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", - 
"ffn2_expert_weight_key": + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", } - self.mlp_text = FusedMoE( + self.text_fused_moe = FusedMoE( fd_config=fd_config, reduce_results=False, moe_intermediate_size=fd_config.model_config. @@ -116,9 +117,9 @@ class Ernie4_5_VLMoE(nn.Layer): moe_tag="Text", weight_key_map=weight_key_map, ) - self.mlp_text.extract_gate_correction_bias = self.extract_gate_correction_bias_text + self.text_fused_moe.extract_gate_correction_bias = self.extract_gate_correction_bias_text else: - self.mlp_text = Ernie4_5_VLMLP( + self.text_fused_moe = Ernie4_5_VLMLP( fd_config=fd_config, intermediate_size=fd_config.model_config.intermediate_size, prefix=f"{prefix}", @@ -131,12 +132,12 @@ class Ernie4_5_VLMoE(nn.Layer): f"{prefix}.gate.weight_1", "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias", - "ffn1_expert_weight_key": + "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", - "ffn2_expert_weight_key": + "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", } - self.mlp_image = FusedMoE( + self.image_fused_moe = FusedMoE( fd_config=fd_config, reduce_results=False, moe_intermediate_size=fd_config.model_config. @@ -148,9 +149,9 @@ class Ernie4_5_VLMoE(nn.Layer): moe_tag="Image", weight_key_map=weight_key_map, ) - self.mlp_image.extract_gate_correction_bias = self.extract_gate_correction_bias_image + self.image_fused_moe.extract_gate_correction_bias = self.extract_gate_correction_bias_image else: - self.mlp_image = Ernie4_5_VLMLP( + self.image_fused_moe = Ernie4_5_VLMLP( fd_config=fd_config, intermediate_size=fd_config.model_config.intermediate_size, prefix=f"{prefix}", @@ -185,10 +186,10 @@ class Ernie4_5_VLMoE(nn.Layer): return gate_correction_bias_tensor[1].unsqueeze(0) def load_state_dict(self, state_dict): - self.mlp_text.load_state_dict(state_dict) - self.mlp_image.load_state_dict(state_dict) - if self.mlp_text.moe_use_gate_correction_bias: - state_dict.pop(self.mlp_text.gate_correction_bias_key) + self.text_fused_moe.load_state_dict(state_dict) + self.image_fused_moe.load_state_dict(state_dict) + if self.text_fused_moe.moe_use_gate_correction_bias: + state_dict.pop(self.text_fused_moe.gate_correction_bias_key) if self.num_shared_experts > 0: self.share_experts.load_state_dict(state_dict) @@ -205,8 +206,8 @@ class Ernie4_5_VLMoE(nn.Layer): vl_moe_meta.image_index, True, ) - text_out = self.mlp_text(vl_moe_meta.text_input) - image_out = self.mlp_image(vl_moe_meta.image_input) + text_out = self.text_fused_moe(vl_moe_meta.text_input) + image_out = self.image_fused_moe(vl_moe_meta.image_input) text_image_gather_scatter( hidden_states, text_out, @@ -217,7 +218,7 @@ class Ernie4_5_VLMoE(nn.Layer): False, ) else: - hidden_states = self.mlp_text(hidden_states) + hidden_states = self.text_fused_moe(hidden_states) if self.num_shared_experts > 0: hidden_states += share_experts_out if self.tp_size > 1: @@ -342,7 +343,7 @@ class Ernie4_5_VLModel(nn.Layer): self._dtype = fd_config.model_config.dtype fd_config.model_config.pretrained_config.prefix_name = "ernie" - self.embeddings = VocabParallelEmbedding( + self.embed_tokens = VocabParallelEmbedding( fd_config=fd_config, num_embeddings=fd_config.model_config.vocab_size, embedding_dim=fd_config.model_config.hidden_size, @@ -350,7 +351,7 @@ class Ernie4_5_VLModel(nn.Layer): prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"), ) - self.hidden_layers = nn.LayerList([ + self.layers = nn.LayerList([ 
Ernie4_5_VLDecoderLayer( fd_config=fd_config, prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}") @@ -373,11 +374,11 @@ class Ernie4_5_VLModel(nn.Layer): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. """ - self.embeddings.load_state_dict(state_dict) + self.embed_tokens.load_state_dict(state_dict) self.norm.load_state_dict(state_dict) for i in range(self.num_layers): logger.info(f"Start load layer {i}") - self.hidden_layers[i].load_state_dict(state_dict) + self.layers[i].load_state_dict(state_dict) def forward( self, @@ -391,7 +392,7 @@ class Ernie4_5_VLModel(nn.Layer): image_index = None image_token_num = 0 - hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding) + hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding) # ----------------------- image_mask = ids_remove_padding == self.im_patch_id @@ -424,7 +425,7 @@ class Ernie4_5_VLModel(nn.Layer): residual = None for i in range(self.num_layers): - hidden_states, residual = self.hidden_layers[i]( + hidden_states, residual = self.layers[i]( forward_meta, hidden_states, residual, @@ -539,8 +540,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): self.vision_model.load_state_dict(state_dict) self.resampler_model.load_state_dict(state_dict) if self.tie_word_embeddings: - self.lm_head.out_linear.weight.set_value( - self.ernie.embeddings.word_embeddings.weight.transpose([1, 0])) + self.lm_head.linear.weight.set_value( + self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) else: self.lm_head.load_state_dict(state_dict) @@ -666,7 +667,7 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel): ] @classmethod - def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True): + def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True): """ get_tensor_parallel_mappings """ @@ -686,10 +687,10 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel): is_split=is_split, tensor_parallel_degree=config.tensor_parallel_degree, tensor_parallel_rank=config.tensor_parallel_rank, - num_attention_heads=config.vision_config.num_heads, - num_key_value_heads=config.vision_config.num_heads, - head_dim=config.vision_config.hidden_size - // config.vision_config.num_heads, + num_attention_heads=config.vision_config.get("num_heads"), + num_key_value_heads=config.vision_config.get("num_heads"), + head_dim=config.vision_config.get("hidden_size") + // config.vision_config.get("num_heads"), ) def get_tensor_parallel_split_mappings( @@ -754,7 +755,7 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel): config.prefix_name, ) vision_mappings = get_vison_parallel_split_mappings( - config.vision_config.depth + config.vision_config.get("depth") ) return {**mappings, **vision_mappings} diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py index 81e004107..3ffc7874e 100644 --- a/fastdeploy/model_executor/models/qwen2.py +++ b/fastdeploy/model_executor/models/qwen2.py @@ -48,7 +48,7 @@ class Qwen2MLP(nn.Layer): ) -> None: super().__init__() self.nranks = fd_config.parallel_config.tensor_parallel_size - self.gate_up_proj = MergedColumnParallelLinear( + self.up_gate_proj = MergedColumnParallelLinear( fd_config=fd_config, prefix=f"{prefix}.up_gate_proj", input_size=fd_config.model_config.hidden_size, @@ -67,20 +67,20 @@ class Qwen2MLP(nn.Layer): self.act_fn = SiluAndMul( fd_config=fd_config, - bias=getattr(self.gate_up_proj, "linear_bias", None), + 
bias=getattr(self.up_gate_proj, "bias", None), act_method=fd_config.model_config.hidden_act, ) def load_state_dict(self, state_dict): """ """ - self.gate_up_proj.load_state_dict(state_dict) + self.up_gate_proj.load_state_dict(state_dict) self.down_proj.load_state_dict(state_dict) def forward(self, x): """ """ - gate_up_out = self.gate_up_proj(x) + gate_up_out = self.up_gate_proj(x) act_out = self.act_fn(gate_up_out) down_out = self.down_proj(act_out) return down_out @@ -230,7 +230,7 @@ class Qwen2Model(nn.Layer): self.num_layers = fd_config.model_config.num_hidden_layers fd_config.model_config.pretrained_config.prefix_name = "qwen2" - self.embeddings = VocabParallelEmbedding( + self.embed_tokens = VocabParallelEmbedding( fd_config=fd_config, num_embeddings=fd_config.model_config.vocab_size, embedding_dim=fd_config.model_config.hidden_size, @@ -261,7 +261,7 @@ class Qwen2Model(nn.Layer): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. """ - self.embeddings.load_state_dict(state_dict) + self.embed_tokens.load_state_dict(state_dict) self.norm.load_state_dict(state_dict) for i in range(self.num_layers): logger.info(f"Start load layer {i}") @@ -275,7 +275,7 @@ class Qwen2Model(nn.Layer): """ """ - hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding) + hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding) residual = None @@ -303,7 +303,7 @@ class Qwen2ForCausalLM(ModelForCasualLM): super(Qwen2ForCausalLM, self).__init__(fd_config) self.fd_config =fd_config - self.model = Qwen2Model(fd_config=fd_config) + self.qwen2 = Qwen2Model(fd_config=fd_config) self.ori_vocab_size = fd_config.model_config.ori_vocab_size @@ -330,7 +330,7 @@ class Qwen2ForCausalLM(ModelForCasualLM): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. """ - self.model.load_state_dict(state_dict) + self.qwen2.load_state_dict(state_dict) self.lm_head.load_state_dict(state_dict) def compute_logits(self, hidden_states: paddle.Tensor): @@ -349,7 +349,7 @@ class Qwen2ForCausalLM(ModelForCasualLM): ): """ """ - hidden_states = self.model(ids_remove_padding=ids_remove_padding, + hidden_states = self.qwen2(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) return hidden_states diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 5a75a868e..4f7642bee 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -166,7 +166,7 @@ class Qwen3Model(nn.Layer): self.num_layers = fd_config.model_config.num_hidden_layers fd_config.model_config.pretrained_config.prefix_name = "model" - self.embeddings = VocabParallelEmbedding( + self.embed_tokens = VocabParallelEmbedding( fd_config=fd_config, num_embeddings=fd_config.model_config.vocab_size, embedding_dim=fd_config.model_config.hidden_size, @@ -197,7 +197,7 @@ class Qwen3Model(nn.Layer): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. 
""" - self.embeddings.load_state_dict(state_dict) + self.embed_tokens.load_state_dict(state_dict) self.norm.load_state_dict(state_dict) for i in range(self.num_layers): logger.info(f"Start load layer {i}") @@ -210,7 +210,7 @@ class Qwen3Model(nn.Layer): ): """ """ - hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding) + hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding) residual = None @@ -266,8 +266,8 @@ class Qwen3ForCausalLM(ModelForCasualLM): """ self.model.load_state_dict(state_dict) if self.tie_word_embeddings: - self.lm_head.out_linear.weight.set_value( - self.model.embeddings.word_embeddings.weight.transpose([1, 0])) + self.lm_head.linear.weight.set_value( + self.model.embed_tokens.embeddings.weight.transpose([1, 0])) else: self.lm_head.load_state_dict(state_dict) diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index b222f48ab..11d387a54 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -50,7 +50,7 @@ class Qwen3MLP(nn.Layer): super().__init__() self.nranks = fd_config.parallel_config.tensor_parallel_size - self.gate_up_proj = MergedColumnParallelLinear( + self.up_gate_proj = MergedColumnParallelLinear( fd_config, prefix=f"{prefix}.up_gate_proj", input_size=fd_config.model_config.hidden_size, @@ -69,20 +69,20 @@ class Qwen3MLP(nn.Layer): self.act_fn = SiluAndMul( fd_config, - bias=getattr(self.gate_up_proj, "linear_bias", None), + bias=getattr(self.up_gate_proj, "bias", None), act_method=fd_config.model_config.hidden_act, ) def load_state_dict(self, state_dict): """ """ - self.gate_up_proj.load_state_dict(state_dict) + self.up_gate_proj.load_state_dict(state_dict) self.down_proj.load_state_dict(state_dict) def forward(self, x): """ """ - gate_up_out = self.gate_up_proj(x) + gate_up_out = self.up_gate_proj(x) act_out = self.act_fn(gate_up_out) down_out = self.down_proj(act_out) return down_out @@ -108,9 +108,9 @@ class Qwen3DecoderLayer(nn.Layer): weight_key_map = { "gate_weight_key": f"{prefix}.mlp.gate.weight", - "ffn1_expert_weight_key": + "up_gate_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.up_gate_proj.weight", - "ffn2_expert_weight_key": + "down_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.down_proj.weight", } @@ -201,7 +201,7 @@ class Qwen3MoeModel(nn.Layer): self.num_layers = fd_config.model_config.num_hidden_layers fd_config.model_config.pretrained_config.prefix_name = "model" - self.embeddings = VocabParallelEmbedding( + self.embed_tokens = VocabParallelEmbedding( fd_config, num_embeddings=fd_config.model_config.vocab_size, embedding_dim=fd_config.model_config.hidden_size, @@ -232,7 +232,7 @@ class Qwen3MoeModel(nn.Layer): A dictionary containing model parameters, where keys are parameter names and values are NumPy arrays or PaddlePaddle tensors. 
""" - self.embeddings.load_state_dict(state_dict) + self.embed_tokens.load_state_dict(state_dict) self.norm.load_state_dict(state_dict) for i in range(self.num_layers): logger.info(f"Start load layer {i}") @@ -245,7 +245,7 @@ class Qwen3MoeModel(nn.Layer): ): """ """ - hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding) + hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding) residual = None diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py index 0bf40611e..14a9edfad 100644 --- a/fastdeploy/model_executor/models/utils.py +++ b/fastdeploy/model_executor/models/utils.py @@ -311,18 +311,18 @@ def w4a8_weight_convert(state_dict): w4a8_weight_bites_layers_map = {} w4a8_weight_bites_layers_map["qkv_gemm_bits_map"] = [] w4a8_weight_bites_layers_map["out_gemm_bits_map"] = [] - w4a8_weight_bites_layers_map["ffn1_gemm_bits_map"] = [] - w4a8_weight_bites_layers_map["ffn2_gemm_bits_map"] = [] + w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"] = [] + w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"] = [] for name_keys, gemm_bits in w4a8_weight_bites_name_map.items(): if "qkv_proj" in name_keys: w4a8_weight_bites_layers_map["qkv_gemm_bits_map"].append(gemm_bits) elif "out_proj" in name_keys: w4a8_weight_bites_layers_map["out_gemm_bits_map"].append(gemm_bits) elif "linear1" in name_keys: - w4a8_weight_bites_layers_map["ffn1_gemm_bits_map"].append( + w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"].append( gemm_bits) elif "linear2" in name_keys: - w4a8_weight_bites_layers_map["ffn2_gemm_bits_map"].append( + w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"].append( gemm_bits) logger.debug( f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}") diff --git a/fastdeploy/model_executor/ops/iluvatar/moe_ops.py b/fastdeploy/model_executor/ops/iluvatar/moe_ops.py index 327cffb6e..ad77f8b69 100644 --- a/fastdeploy/model_executor/ops/iluvatar/moe_ops.py +++ b/fastdeploy/model_executor/ops/iluvatar/moe_ops.py @@ -15,9 +15,10 @@ """ from typing import Optional + import paddle -from paddle.nn.quant import weight_only_linear from paddle.incubate.nn.functional import swiglu +from paddle.nn.quant import weight_only_linear def group_gemm( @@ -71,31 +72,31 @@ def group_gemm( def iluvatar_moe_expert_ffn( permute_input: paddle.Tensor, tokens_expert_prefix_sum: paddle.Tensor, - ffn1_weight: paddle.Tensor, - ffn2_weight: paddle.Tensor, - ffn1_bias: Optional[paddle.Tensor], - ffn1_scale: Optional[paddle.Tensor], - ffn2_scale: Optional[paddle.Tensor], - ffn2_in_scale: Optional[paddle.Tensor], + up_gate_proj_weight: paddle.Tensor, + down_proj_weight: paddle.Tensor, + up_gate_proj_bias: Optional[paddle.Tensor], + up_gate_proj_scale: Optional[paddle.Tensor], + down_proj_scale: Optional[paddle.Tensor], + down_proj_in_scale: Optional[paddle.Tensor], expert_idx_per_token: Optional[paddle.Tensor], quant_method: str, used_in_ep_low_latency: bool, ): - assert ffn1_bias is None - assert ffn1_scale is not None - assert ffn2_scale is not None - assert ffn2_in_scale is None + assert up_gate_proj_bias is None + assert up_gate_proj_scale is not None + assert down_proj_scale is not None + assert down_proj_in_scale is None assert expert_idx_per_token is None assert quant_method in ("weight_only_int8") assert not used_in_ep_low_latency tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu") - ffn1_output = paddle.empty([permute_input.shape[0], ffn1_weight.shape[1]], + up_gate_proj_output = 
paddle.empty([permute_input.shape[0], up_gate_proj_weight.shape[1]], dtype=permute_input.dtype) - group_gemm(permute_input, tokens_expert_prefix_sum_cpu, ffn1_weight, - ffn1_scale, ffn1_output) - act_out = swiglu(ffn1_output) - output = paddle.empty([act_out.shape[0], ffn2_weight.shape[1]], + group_gemm(permute_input, tokens_expert_prefix_sum_cpu, up_gate_proj_weight, + up_gate_proj_scale, up_gate_proj_output) + act_out = swiglu(up_gate_proj_output) + output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]], dtype=act_out.dtype) - group_gemm(act_out, tokens_expert_prefix_sum_cpu, ffn2_weight, ffn2_scale, + group_gemm(act_out, tokens_expert_prefix_sum_cpu, down_proj_weight, down_proj_scale, output) return output diff --git a/fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe.py b/fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe.py index 0efe07afa..fe279a8e5 100644 --- a/fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe.py +++ b/fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe.py @@ -386,19 +386,19 @@ def moe_wint2_ffn_kernel( def fused_moe_wint2_impl( hidden_states, - ffn1_quant_weight, - ffn2_quant_weight, + up_gate_proj_quant_weight, + down_proj_quant_weight, topk_weights, topk_ids, # inplace: bool = False, - ffn1_weight_scale=None, - ffn2_weight_scale=None, - ffn1_super_scales=None, - ffn2_super_scales=None, - ffn1_code_scale=None, - ffn2_code_scale=None, - ffn1_code_zp=None, - ffn2_code_zp=None, + up_gate_proj_weight_scale=None, + down_proj_weight_scale=None, + up_gate_proj_super_scales=None, + down_proj_super_scales=None, + up_gate_proj_code_scale=None, + down_proj_code_scale=None, + up_gate_proj_code_zp=None, + down_proj_code_zp=None, group_size=64, bit="wint2", ): @@ -408,22 +408,22 @@ def fused_moe_wint2_impl( # Check constraints. 
# A: [M, K] # B: [E, K, N] - # assert hidden_states.shape[1] == ffn1_weight_scale.shape[1], - # f"Hidden size mismatch, {hidden_states.shape[1]} != {ffn1_quant_weight.shape[1]}" + # assert hidden_states.shape[1] == up_gate_proj_weight_scale.shape[1], + # f"Hidden size mismatch, {hidden_states.shape[1]} != {up_gate_proj_quant_weight.shape[1]}" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert ffn1_quant_weight.is_contiguous( + assert up_gate_proj_quant_weight.is_contiguous( ), "Expert weights1 must be contiguous" - assert ffn2_quant_weight.is_contiguous( + assert down_proj_quant_weight.is_contiguous( ), "Expert weights2 must be contiguous" assert group_size > 0, "Group size must be greater than 0" num_tokens, K = hidden_states.shape - E, _, N = ffn1_quant_weight.shape + E, _, N = up_gate_proj_quant_weight.shape M = num_tokens if group_size < 0: - group_size = K // ffn1_weight_scale.shape[1] + group_size = K // up_gate_proj_weight_scale.shape[1] top_k = topk_ids.shape[1] @@ -448,12 +448,12 @@ def fused_moe_wint2_impl( invoke_fused_moe_kernel( A=hidden_states, - B=ffn1_quant_weight, + B=up_gate_proj_quant_weight, C=intermediate_cache1, - B_scale=ffn1_weight_scale, - B_super_scale=ffn1_super_scales, - B_code_scale=ffn1_code_scale, - B_code_zp=ffn1_code_zp, + B_scale=up_gate_proj_weight_scale, + B_super_scale=up_gate_proj_super_scales, + B_code_scale=up_gate_proj_code_scale, + B_code_zp=up_gate_proj_code_zp, topk_weights=topk_weights, topk_ids=topk_ids, sorted_token_ids=sorted_token_ids, @@ -469,12 +469,12 @@ def fused_moe_wint2_impl( invoke_fused_moe_kernel( A=intermediate_cache2, - B=ffn2_quant_weight, + B=down_proj_quant_weight, C=intermediate_cache3, - B_scale=ffn2_weight_scale, - B_super_scale=ffn2_super_scales, - B_code_scale=ffn2_code_scale, - B_code_zp=ffn2_code_zp, + B_scale=down_proj_weight_scale, + B_super_scale=down_proj_super_scales, + B_code_scale=down_proj_code_scale, + B_code_zp=down_proj_code_zp, topk_weights=topk_weights, topk_ids=topk_ids, sorted_token_ids=sorted_token_ids, @@ -491,37 +491,37 @@ def fused_moe_wint2_impl( def fused_moe_wint2_triton( hidden_states, - ffn1_quant_weight, - ffn2_quant_weight, + up_gate_proj_quant_weight, + down_proj_quant_weight, scores, gate_correction_bias, topk, - ffn1_weight_scale, - ffn2_weight_scale, - ffn1_super_scales, - ffn2_super_scales, - ffn1_code_scale, - ffn2_code_scale, - ffn1_code_zp, - ffn2_code_zp, + up_gate_proj_weight_scale, + down_proj_weight_scale, + up_gate_proj_super_scales, + down_proj_super_scales, + up_gate_proj_code_scale, + down_proj_code_scale, + up_gate_proj_code_zp, + down_proj_code_zp, ): """ Fuse MoE with WINT2 quantization scheme and Triton backend. Args: hidden_states: input tensor. - ffn1_quant_weight: ffn1 weight matrix for experts. - ffn2_quant_weight: ffn2 weight matrix for experts. + up_gate_proj_quant_weight: up_gate_proj weight matrix for experts. + down_proj_quant_weight: down_proj weight matrix for experts. scores: gate scores. gate_correction_bias: bias correction for gates. topk: number of experts to use. - ffn1_weight_scale: scaling factor for ffn1_quant_weight. - ffn2_weight_scale: scaling factor for ffn2_quant_weight. - ffn1_super_scales: super scaling factor for ffn1_scale. - ffn2_super_scales: super scaling factor for ffn2_weight_scale. - ffn1_code_scale: code scaling factor for ffn1_quant_weight. - ffn2_code_scale: code scaling factor for ffn2_quant_weight. 
-        ffn1_code_zp: code zero point for ffn1_quant_weight.
-        ffn2_code_zp: code zero point for ffn2_quant_weight.
+        up_gate_proj_weight_scale: scaling factor for up_gate_proj_quant_weight.
+        down_proj_weight_scale: scaling factor for down_proj_quant_weight.
+        up_gate_proj_super_scales: super scaling factor for up_gate_proj_weight_scale.
+        down_proj_super_scales: super scaling factor for down_proj_weight_scale.
+        up_gate_proj_code_scale: code scaling factor for up_gate_proj_quant_weight.
+        down_proj_code_scale: code scaling factor for down_proj_quant_weight.
+        up_gate_proj_code_zp: code zero point for up_gate_proj_quant_weight.
+        down_proj_code_zp: code zero point for down_proj_quant_weight.
     Returns:
         output tensor.
     """
@@ -533,17 +533,17 @@ def fused_moe_wint2_triton(
     return fused_moe_wint2_impl(
         hidden_states,
-        ffn1_quant_weight,
-        ffn2_quant_weight,
+        up_gate_proj_quant_weight,
+        down_proj_quant_weight,
         topk_weights,
         topk_ids,
-        ffn1_weight_scale,
-        ffn2_weight_scale,
-        ffn1_super_scales,
-        ffn2_super_scales,
-        ffn1_code_scale,
-        ffn2_code_scale,
-        ffn1_code_zp,
-        ffn2_code_zp,
+        up_gate_proj_weight_scale,
+        down_proj_weight_scale,
+        up_gate_proj_super_scales,
+        down_proj_super_scales,
+        up_gate_proj_code_scale,
+        down_proj_code_scale,
+        up_gate_proj_code_zp,
+        down_proj_code_zp,
         bit="wint2",
     )
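The rename in this file is mechanical, so call sites only need their keyword names updated. Below is a minimal sketch of a hypothetical compatibility shim for code still written against the ffn1_*/ffn2_* names; the mapping is taken directly from the signature change above, while call_with_legacy_names itself is illustrative and not part of the patch.

_LEGACY_TO_NEW_KWARGS = {
    "ffn1_quant_weight": "up_gate_proj_quant_weight",
    "ffn2_quant_weight": "down_proj_quant_weight",
    "ffn1_weight_scale": "up_gate_proj_weight_scale",
    "ffn2_weight_scale": "down_proj_weight_scale",
    "ffn1_super_scales": "up_gate_proj_super_scales",
    "ffn2_super_scales": "down_proj_super_scales",
    "ffn1_code_scale": "up_gate_proj_code_scale",
    "ffn2_code_scale": "down_proj_code_scale",
    "ffn1_code_zp": "up_gate_proj_code_zp",
    "ffn2_code_zp": "down_proj_code_zp",
}


def call_with_legacy_names(func, *args, **kwargs):
    """Forward a call written against the old ffn1_*/ffn2_* keywords to the renamed API."""
    remapped = {_LEGACY_TO_NEW_KWARGS.get(key, key): value for key, value in kwargs.items()}
    return func(*args, **remapped)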
diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py
index ab16741e4..99ab3455f 100644
--- a/fastdeploy/rl/rollout_model.py
+++ b/fastdeploy/rl/rollout_model.py
@@ -24,11 +24,14 @@ from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.model_loader import ModelRegistry
 from fastdeploy.model_executor.models.ernie4_5_moe import \
     Ernie4_5_MoeForCausalLM
+from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import \
+    Ernie4_5_VLMoeForConditionalGeneration
 from fastdeploy.model_executor.models.qwen2 import Qwen2ForCausalLM
 from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM
 from fastdeploy.model_executor.models.qwen3moe import Qwen3MoeForCausalLM
 from fastdeploy.rl.rollout_config import RolloutModelConfig
+
 
 class RolloutModel(nn.Layer):
     """Main model class for rollout operations, supports multimodal components for train."""
@@ -36,55 +39,26 @@ class RolloutModel(nn.Layer):
         """Initialize with FastDeploy configuration."""
         super(RolloutModel, self).__init__()
         self.fd_config = rollout_model_config.initialize()
-        self._init_models()
-
-    def _init_models(self):
-        """Initialize all model components including multimodal if needed."""
-        self.is_vl = "VL" in self.fd_config.model_config.architectures[0]
-        self.rollout_model = self._load_primary_model()
-        self.rollout_models = [self.rollout_model]
-
-        if self.is_vl:
-            self._init_multimodal_models()
-            self.rollout_models.extend(
-                [self.vision_model, self.resampler_model])
-
-    def _init_multimodal_models(self):
-        """Initialize vision and resampler components for multimodal models."""
-        # TODO:(gaoziyuan) Implement actual initialization
-        self.vision_model = nn.Layer()
-        self.resampler_model = nn.Layer()
-
-    def _load_primary_model(self):
-        """Load main model from loader based on config."""
-        if "VL" in self.fd_config.model_config.architectures[0]:
-            logger.error("Loaded Vision Language model, not support now")
+        self._init_model()
 
+    def _init_model(self):
+        """Load model from loader based on config."""
         context = paddle.LazyGuard()
         architectures = f"{self.fd_config.model_config.architectures[0]}RL"
         with context:
             model_cls = ModelRegistry.get_class(architectures)
             model = model_cls(self.fd_config)
-        model.eval()
-        return model
+        self.rollout_model = model.eval()
 
     def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
-        mappings = {}
-        for model in self.rollout_models:
-            mappings.update(
-                getattr(model, "get_name_mappings_to_training", lambda: {})())
-        return mappings
+        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})()
 
     @paddle.no_grad()
     def state_dict(self):
         """state_dict"""
-        all_params = {}
-        for model in self.rollout_models:
-            for name, param in model.state_dict().items():
-                all_params[name] = param
-        return all_params
+        return self.rollout_model.state_dict()
 
 
 class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
@@ -113,98 +87,159 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
         # Initialize mapping dictionary
         infer_to_train = {}
-        infer_base_name = "model"
-        train_base_name = "ernie"
+        base_name = "ernie"
         # Static mappings (non-layer specific)
         static_mappings = {
-            f"{infer_base_name}.embeddings.word_embeddings.weight":
-            f"{train_base_name}.embed_tokens.weight",
-            f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
-            "lm_head.out_linear.weight": "lm_head.weight"
+            f"{base_name}.embed_tokens.embeddings.weight":
+            f"{base_name}.embed_tokens.weight",
+            "lm_head.linear.weight": "lm_head.weight"
         }
-        if self.fd_config.model_config.get("weight_sharing", False):
+        if self.fd_config.model_config.get("tie_word_embeddings", False):
             # Support tie_word_embeddings
             logger.debug("enable tie_word_embeddings")
-            static_mappings.pop("lm_head.out_linear.weight")
+            static_mappings.pop("lm_head.linear.weight")
         infer_to_train.update(static_mappings)
-        infer_base_name = infer_base_name + ".hidden_layers"
-        train_base_name = train_base_name + ".layers"
+        base_name = base_name + ".layers"
         # Helper function to add layer mappings
-        def _add_layer_mappings(layer_idx, is_moe_layer=False):
-            # Handle special case for layer 0's input layernorm
-            for ph in place_holders:
-                infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
-                train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
-                infer_to_train[infer_key] = train_key
+        def _add_layer_mappings(layer_idx: int):
+            # MoE specific mappings
+            infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
+                f"{base_name}.{layer_idx}.mlp.gate.weight"
-            # Common attention mappings
-            for ph in place_holders:
-                infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
-                    f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"
+            if self.fd_config.model_config.moe_use_aux_free:
+                infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
+                    f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
-                infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
-                    f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"
-
-            # Post-attention layernorm
-            for ph in place_holders:
-                infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
-                    f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"
-
-            if not is_moe_layer:
-                # Dense FFN mappings
+            # MoE experts mappings
+            for expert_idx in range(self.fd_config.model_config.moe_num_experts):
                 for ph in place_holders:
-                    infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
-                        f"{train_base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"
+                    # up_gate_proj
f"{base_name}.{layer_idx}.mlp.fused_moe.up_gate_proj_weight" + if up_gate_proj_key not in infer_to_train: + infer_to_train[up_gate_proj_key] = [] + infer_to_train[up_gate_proj_key].append( + f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}" + ) - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \ - f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}" - else: - # MoE specific mappings - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \ - f"{train_base_name}.{layer_idx}.mlp.gate.weight" - - if self.fd_config.model_config.moe_use_aux_free: - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \ - f"{train_base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias" - - # Support shared experts - if self.fd_config.model_config.get( - "moe_num_shared_experts") > 0: - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \ - f"{train_base_name}.{layer_idx}.mlp.shared_experts.up_gate_proj.weight" - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \ - f"{train_base_name}.{layer_idx}.mlp.shared_experts.down_proj.weight" - - # MoE experts mappings - for expert_idx in range(self.fd_config.model_config.moe_num_experts): - for ph in place_holders: - # FFN1 (up_gate_proj) - ffn1_key = f"{infer_base_name}.{layer_idx}.mlp.fused_moe.moe_ffn1_weight" - if ffn1_key not in infer_to_train: - infer_to_train[ffn1_key] = [] - infer_to_train[ffn1_key].append( - f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}" - ) - - # FFN2 (down_proj) - ffn2_key = f"{infer_base_name}.{layer_idx}.mlp.fused_moe.moe_ffn2_weight" - if ffn2_key not in infer_to_train: - infer_to_train[ffn2_key] = [] - infer_to_train[ffn2_key].append( - f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}" - ) - - # Process non-MoE layers - for layer_idx in range( - self.fd_config.model_config.moe_layer_start_index): - _add_layer_mappings(layer_idx, is_moe_layer=False) + # down_proj (down_proj) + down_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.down_proj_weight" + if down_proj_key not in infer_to_train: + infer_to_train[down_proj_key] = [] + infer_to_train[down_proj_key].append( + f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}" + ) + assert isinstance(self.fd_config.model_config.moe_layer_start_index, int) # Process MoE layers for layer_idx in range(self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.num_hidden_layers): - _add_layer_mappings(layer_idx, is_moe_layer=True) + _add_layer_mappings(layer_idx) + + return infer_to_train + + +class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGeneration): + """ + Ernie4_5_VLMoeForConditionalGenerationRL + """ + + def __init__(self, fd_config: FDConfig): + """ + Args: + fd_config (FDConfig): Configurations for the LLM model. 
+ """ + super(Ernie4_5_VLMoeForConditionalGenerationRL, self).__init__(fd_config) + + @classmethod + def name(self): + """name""" + return "Ernie4_5_VLMoeForConditionalGenerationRL" + + def get_name_mappings_to_training(self): + """Generate mapping between inference and training parameter for RL(donot delete!).""" + have_bias = self.fd_config.model_config.get("have_norm_bias", False) + # Prepare placeholders + place_holders = ["weight"] + (["bias"] if have_bias else []) + + # Initialize mapping dictionary + infer_to_train = {} + + base_name = "ernie" + # Static mappings (non-layer specific) + static_mappings = { + f"{base_name}.embed_tokens.embeddings.weight": + f"{base_name}.embed_tokens.weight", + "lm_head.linear.weight": "lm_head.weight" + } + if self.fd_config.model_config.get("tie_word_embeddings", False): + # Support tie_word_embeddings + logger.debug("enable tie_word_embeddings") + static_mappings.pop("lm_head.linear.weight") + infer_to_train.update(static_mappings) + + base_name = base_name + ".layers" + + # Helper function to add layer mappings + def _add_layer_mappings(layer_idx: int, moe_tag: str): + # MoE specific mappings + infer_to_train[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_weight"] = f"{base_name}.{layer_idx}.mlp.gate.weight" if moe_tag == "text" else f"{base_name}.{layer_idx}.mlp.gate.weight_1" + + if self.fd_config.model_config.moe_use_aux_free: + infer_to_train[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_correction_bias"] = \ + f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias" + + # MoE experts mappings + assert isinstance(self.fd_config.model_config.moe_num_experts, list) + if moe_tag == "text": + expert_idx_start = 0 + expert_idx_end = self.fd_config.model_config.moe_num_experts[0] + else: + expert_idx_start = self.fd_config.model_config.moe_num_experts[0] + expert_idx_end = self.fd_config.model_config.moe_num_experts[1] + + for expert_idx in range(expert_idx_start, expert_idx_end): + for ph in place_holders: + # up_gate_proj (up_gate_proj) + up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight" + if up_gate_proj_key not in infer_to_train: + infer_to_train[up_gate_proj_key] = [] + infer_to_train[up_gate_proj_key].append( + f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}" + ) + + # down_proj (down_proj) + down_proj_key = f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.down_proj_weight" + if down_proj_key not in infer_to_train: + infer_to_train[down_proj_key] = [] + infer_to_train[down_proj_key].append( + f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}" + ) + + moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index + if isinstance(moe_layer_start_index, int): + text_moe_layer_start_index = moe_layer_start_index + image_moe_layer_start_index = moe_layer_start_index + else: + text_moe_layer_start_index = moe_layer_start_index[0] + image_moe_layer_start_index = moe_layer_start_index[1] + + moe_layer_end_index = self.fd_config.model_config.moe_layer_end_index + if moe_layer_end_index is None: + text_moe_layer_end_index = self.fd_config.model_config.num_hidden_layers + image_moe_layer_end_index = self.fd_config.model_config.num_hidden_layers + elif isinstance(moe_layer_end_index, int): + text_moe_layer_end_index = moe_layer_end_index + image_moe_layer_end_index = moe_layer_end_index + else: + text_moe_layer_end_index = moe_layer_end_index[0] + image_moe_layer_end_index = moe_layer_end_index[1] + # Process MoE layers + for layer_idx in 
+        for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
+            _add_layer_mappings(layer_idx, "text")
+        for layer_idx in range(image_moe_layer_start_index, image_moe_layer_end_index):
+            _add_layer_mappings(layer_idx, "image")
 
         return infer_to_train
@@ -234,48 +269,23 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
         # Initialize mapping dictionary
         infer_to_train = {}
-        infer_base_name = "model"
-        train_base_name = "qwen2"
+        base_name = "qwen2"
         # Static mappings (non-layer specific)
         static_mappings = {
-            f"{infer_base_name}.embeddings.word_embeddings.weight":
-            f"{train_base_name}.embed_tokens.weight",
-            f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
-            "lm_head.out_linear.weight": "lm_head.weight"
+            f"{base_name}.embed_tokens.embeddings.weight":
+            f"{base_name}.embed_tokens.weight",
+            "lm_head.linear.weight": "lm_head.weight"
         }
         infer_to_train.update(static_mappings)
-        infer_base_name = infer_base_name + ".layers"
-        train_base_name = train_base_name + ".layers"
+        base_name = base_name + ".layers"
         # Helper function to add layer mappings
         def _add_layer_mappings(layer_idx):
-            # Handle special case for layer 0's input layernorm and attn o_proj
-            for ph in place_holders:
-                infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
-                train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
-                infer_to_train[infer_key] = train_key
-
-                infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
-                    f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"
-
-            # qwen qkv proj need bias
-            for ph in ["weight", "bias"]:
-                infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
-                    f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"
-
-            # Post-attention layernorm
-            for ph in place_holders:
-                infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
-                    f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"
-
             # FFN mappings
             for ph in place_holders:
-                infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
-                    f"{train_base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
-
-                infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
-                    f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
+                infer_to_train[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = \
+                    f"{base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
 
         for layer_idx in range(
                 self.fd_config.model_config.num_hidden_layers):
@@ -309,95 +319,49 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
         # Initialize mapping dictionary
         infer_to_train = {}
-        infer_base_name = "model"
-        train_base_name = "model"
+        base_name = "model"
         # Static mappings (non-layer specific)
         static_mappings = {
-            f"{infer_base_name}.embeddings.word_embeddings.weight":
-            f"{train_base_name}.embed_tokens.weight",
-            f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
-            "lm_head.out_linear.weight": "lm_head.weight"
+            f"{base_name}.embed_tokens.embeddings.weight":
+            f"{base_name}.embed_tokens.weight",
+            "lm_head.linear.weight": "lm_head.weight"
        }
         infer_to_train.update(static_mappings)
-        infer_base_name = infer_base_name + ".layers"
-        train_base_name = train_base_name + ".layers"
+        base_name = base_name + ".layers"
         # Helper function to add layer mappings
-        def _add_layer_mappings(layer_idx, is_moe_layer=False):
-            # Handle special case for layer 0's input layernorm and attn o_proj
-            for ph in place_holders:
-                infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
f"{train_base_name}.{layer_idx}.input_layernorm.{ph}" - infer_to_train[infer_key] = train_key + def _add_layer_mappings(layer_idx: int): + # MoE specific mappings + infer_to_train[f"{base_name}.{layer_idx}.mlp.gate_weight"] = \ + f"{base_name}.{layer_idx}.mlp.gate.weight" - infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \ - f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}" + if self.fd_config.moe_config.moe_use_aux_free: + infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \ + f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias" - # qwen q_norm/k_norm - for ph in place_holders: - infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.q_norm.ln_{ph}"] = \ - f"{train_base_name}.{layer_idx}.self_attn.q_norm.{ph}" - infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.k_norm.ln_{ph}"] = \ - f"{train_base_name}.{layer_idx}.self_attn.k_norm.{ph}" - - # qwen qkv proj - for ph in place_holders: - infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \ - f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}" - - # Post-attention layernorm - for ph in place_holders: - infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \ - f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}" - - if not is_moe_layer: - # FFN mappings + # MoE experts mappings + for expert_idx in range(self.fd_config.moe_config.num_experts): for ph in place_holders: - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \ - f"{train_base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}" + # up_gate_proj (up_gate_proj) + up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.up_gate_proj_weight" + if up_gate_proj_key not in infer_to_train: + infer_to_train[up_gate_proj_key] = [] + infer_to_train[up_gate_proj_key].append( + f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}" + ) - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \ - f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}" - else: - # MoE specific mappings - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_weight"] = \ - f"{train_base_name}.{layer_idx}.mlp.gate.weight" - - if self.fd_config.moe_config.moe_use_aux_free: - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \ - f"{train_base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias" - - # Support shared experts - if self.fd_config.model_config.get( - "moe_num_shared_experts", 0) > 0: - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \ - f"{train_base_name}.{layer_idx}.mlp.shared_experts.up_gate_proj.weight" - infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \ - f"{train_base_name}.{layer_idx}.mlp.shared_experts.down_proj.weight" - - # MoE experts mappings - for expert_idx in range(self.fd_config.moe_config.num_experts): - for ph in place_holders: - # FFN1 (up_gate_proj) - ffn1_key = f"{infer_base_name}.{layer_idx}.mlp.moe_ffn1_weight" - if ffn1_key not in infer_to_train: - infer_to_train[ffn1_key] = [] - infer_to_train[ffn1_key].append( - f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}" - ) - - # FFN2 (down_proj) - ffn2_key = f"{infer_base_name}.{layer_idx}.mlp.moe_ffn2_weight" - if ffn2_key not in infer_to_train: - infer_to_train[ffn2_key] = [] - infer_to_train[ffn2_key].append( - 
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}" - ) + # down_proj (down_proj) + down_proj_key = f"{base_name}.{layer_idx}.mlp.down_proj_weight" + if down_proj_key not in infer_to_train: + infer_to_train[down_proj_key] = [] + infer_to_train[down_proj_key].append( + f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}" + ) # Process MoE layers for layer_idx in range(self.fd_config.model_config.num_hidden_layers): - _add_layer_mappings(layer_idx, is_moe_layer=True) + _add_layer_mappings(layer_idx) return infer_to_train @@ -417,4 +381,4 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM): @classmethod def name(self): """name""" - return "Qwen3ForCausalLMRL" \ No newline at end of file + return "Qwen3ForCausalLMRL"