refactor rl get_name_mappings_to_training (#2847)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* refactor rl get_name_mappings_to_training

* fix tp>1

* change variable name(ffn1->up_gate_proj/ffn2->down_proj)

* change variable name(linear_weight->weight/linear_bias->bias)

* add rl names mapping for vl

* fix ernie 0.3B error

* fix develop code

* fix
This commit is contained in:
Yuanle Liu
2025-07-15 22:31:42 +08:00
committed by GitHub
parent e7bcbbab52
commit 61b3997b85
47 changed files with 1591 additions and 1629 deletions

View File

@@ -116,11 +116,11 @@ PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
paddle::Tensor FusedExpertMoeFunc(
const paddle::Tensor &input, const paddle::Tensor &gate_weight,
const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight,
const paddle::optional<paddle::Tensor> &ffn1_bias,
const paddle::optional<paddle::Tensor> &ffn1_scale,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &ffn2_scale,
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
const paddle::optional<paddle::Tensor> &up_gate_proj_scale,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &down_proj_scale,
const std::string &quant_method, const int moe_topk,
const bool norm_topk_prob, const bool group_moe);
@@ -149,7 +149,7 @@ MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits,
std::vector<paddle::Tensor>
EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids,
const paddle::Tensor &topk_weights,
const paddle::optional<paddle::Tensor> &ffn1_in_scale,
const paddle::optional<paddle::Tensor> &up_gate_proj_in_scale,
const std::vector<int> &token_nums_per_expert,
const int token_nums_this_rank,
const std::string &moe_quant_type);
@@ -173,7 +173,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
const paddle::Tensor &ffn_out, const paddle::Tensor &expert_scales_float,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const bool norm_topk_prob, const float routed_scaling_factor);
std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
@@ -182,35 +182,35 @@ std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertReduceFunc(
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const bool norm_topk_prob, const float routed_scaling_factor);
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
@@ -816,7 +816,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* ep_moe_dispatch
*/
m.def("ep_moe_expert_dispatch", &EPMoeExpertDispatch, py::arg("input"),
py::arg("topk_ids"), py::arg("topk_weights"), py::arg("ffn1_in_scale"),
py::arg("topk_ids"), py::arg("topk_weights"), py::arg("up_gate_proj_in_scale"),
py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"),
py::arg("moe_quant_type"), "ep moe export dispatch function");
@@ -824,7 +824,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"),
py::arg("expert_scales_float"), py::arg("permute_indices_per_token"),
py::arg("top_k_indices"), py::arg("ffn2_bias"),
py::arg("top_k_indices"), py::arg("down_proj_bias"),
py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
"ep moe export combine function");
@@ -866,7 +866,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
*/
m.def("moe_expert_reduce", &MoeExpertReduceFunc, py::arg("ffn_out"),
py::arg("top_k_weight"), py::arg("permute_indices_per_token"),
py::arg("top_k_indices"), py::arg("ffn2_bias"),
py::arg("top_k_indices"), py::arg("down_proj_bias"),
py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
"moe export reduce function");

View File

@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
ffn1_n=7168
ffn1_k=8192
up_gate_proj_n=7168
up_gate_proj_k=8192
ffn2_n=8192
ffn2_k=3584
rm -rf ffn1_7168_8192.log
rm -rf ffn2_8192_3584.log
down_proj_n=8192
down_proj_k=3584
rm -rf up_gate_proj_7168_8192.log
rm -rf down_proj_8192_3584.log
num_experts=8
for tokens_per_expert in 12
do
wait
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 1 0 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 1 0 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
done
wait
echo "#### finish ####"

View File

@@ -161,7 +161,7 @@ __global__ void combine_prmt_back_kernel(
expanded_permuted_rows + expanded_permuted_row * cols; // prmt后的位置对应的值
Load<T, VEC_SIZE>(expanded_permuted_rows_row_ptr + tid * VEC_SIZE, &load_vec);
const int expert_idx = expert_for_source_row[k_offset]; // 当前位置对应的专家
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的ffn2的bias
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的down_proj的bias
if (bias_ptr) {
Load<T, VEC_SIZE>(bias_ptr + tid * VEC_SIZE, &bias_vec);
#pragma unroll
@@ -188,7 +188,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out,
const paddle::Tensor& expert_scales_float,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor,
const int num_rows,
@@ -206,7 +206,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out,
combine_prmt_back_kernel<<<gridx, threads, 0, stream>>>(
ffn_out.data<data_t>(),
output->data<data_t>(),
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
expert_scales_float.data<float>(),
permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(),
@@ -223,7 +223,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
const paddle::Tensor& expert_scales_float, // dst_weights
const paddle::Tensor& permute_indices_per_token, // permute_indices_per_token
const paddle::Tensor& top_k_indices, // dst_indices
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor) {
@@ -242,7 +242,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
expert_scales_float,
permute_indices_per_token,
top_k_indices,
ffn2_bias,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
@@ -255,7 +255,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
expert_scales_float,
permute_indices_per_token,
top_k_indices,
ffn2_bias,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
@@ -274,7 +274,7 @@ __global__ void permute_x_kernel(const T *src_x,
const int64_t *topk_idx,
const float *topk_weights,
const int *token_nums_per_expert,
const float *ffn1_in_scale,
const float *up_gate_proj_in_scale,
const int moe_topk,
const int num_rows,
const int token_nums_this_rank,
@@ -327,9 +327,9 @@ __global__ void permute_x_kernel(const T *src_x,
// cp x
for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) {
Load<T, vec_size>(&src_x[s_token_idx * hidden_size + v_id * vec_size], &src_vec);
if (ffn1_in_scale) {
if (up_gate_proj_in_scale) {
for (int i = 0; i < vec_size; i++) {
float quant_value = max_bound * ffn1_in_scale[expert_now] * static_cast<float>(src_vec[i]);
float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
if (RoundType == 0) {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
} else {
@@ -353,7 +353,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights,
const paddle::Tensor& token_nums_per_expert,
const paddle::optional<paddle::Tensor>& ffn1_in_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_in_scale,
const std::string& moe_quant_type,
const int moe_topk,
const int num_rows,
@@ -383,7 +383,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
@@ -404,7 +404,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
@@ -427,7 +427,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
@@ -448,7 +448,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
@@ -472,7 +472,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
const paddle::Tensor& input,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights,
const paddle::optional<paddle::Tensor>& ffn1_in_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_in_scale,
const std::vector<int>& token_nums_per_expert,
const int token_nums_this_rank,
const std::string& moe_quant_type) {
@@ -516,7 +516,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
topk_ids,
topk_weights,
num_experts_per_rank_tensor,
ffn1_in_scale,
up_gate_proj_in_scale,
moe_quant_type,
moe_topk,
num_rows,
@@ -536,7 +536,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
topk_ids,
topk_weights,
num_experts_per_rank_tensor,
ffn1_in_scale,
up_gate_proj_in_scale,
moe_quant_type,
moe_topk,
num_rows,
@@ -568,7 +568,7 @@ std::vector<std::vector<int64_t>> EPMoeExpertDispatchInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& topk_ids_shape,
const std::vector<int64_t>& topk_weights_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_in_scale_dtype,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_in_scale_dtype,
const std::vector<int>& token_nums_per_expert,
const int token_nums_this_rank) {
int token_rows = -1;
@@ -610,7 +610,7 @@ std::vector<paddle::DataType> EPMoeExpertDispatchInferDtype(
PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
.Inputs({"input", "topk_ids", "topk_weights",
paddle::Optional("ffn1_in_scale")})
paddle::Optional("up_gate_proj_in_scale")})
.Outputs({"permute_input",
"permute_indices_per_token",
"token_nums_per_expert_cumsum",

View File

@@ -54,12 +54,12 @@ void compute_total_rows_before_expert(int* sorted_indices,
template <paddle::DataType T>
void FusedMoeKernel(const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& ffn1_weight,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::Tensor& up_gate_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const std::string& quant_method,
const int moe_topk,
const bool group_moe,
@@ -84,12 +84,12 @@ void FusedMoeKernel(const paddle::Tensor& input,
moe_compute.ComputeFFN(&input,
&gate_weight,
&ffn1_weight,
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
&ffn2_weight,
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
&up_gate_proj_weight,
up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
&down_proj_weight,
down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
nullptr,
moe_topk,
group_moe,
@@ -102,12 +102,12 @@ void FusedMoeKernel(const paddle::Tensor& input,
paddle::Tensor FusedExpertMoeFunc(
const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const std::string& quant_method,
const int moe_topk,
const bool norm_topk_prob,
@@ -119,12 +119,12 @@ paddle::Tensor FusedExpertMoeFunc(
case paddle::DataType::BFLOAT16:
FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
gate_weight,
ffn1_weight,
ffn1_scale,
ffn1_bias,
ffn2_weight,
ffn2_scale,
ffn2_bias,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_bias,
down_proj_weight,
down_proj_scale,
down_proj_bias,
quant_method,
moe_topk,
group_moe,
@@ -134,12 +134,12 @@ paddle::Tensor FusedExpertMoeFunc(
case paddle::DataType::FLOAT16:
FusedMoeKernel<paddle::DataType::FLOAT16>(input,
gate_weight,
ffn1_weight,
ffn1_scale,
ffn1_bias,
ffn2_weight,
ffn2_scale,
ffn2_bias,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_bias,
down_proj_weight,
down_proj_scale,
down_proj_bias,
quant_method,
moe_topk,
group_moe,
@@ -155,24 +155,24 @@ paddle::Tensor FusedExpertMoeFunc(
std::vector<paddle::Tensor> FusedExpertMoe(
const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const std::string& quant_method,
const int moe_topk,
const bool norm_topk_prob,
const bool group_moe) {
return {FusedExpertMoeFunc(input,
gate_weight,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_bias,
ffn2_scale,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_bias,
down_proj_scale,
quant_method,
moe_topk,
norm_topk_prob,
@@ -182,24 +182,24 @@ std::vector<paddle::Tensor> FusedExpertMoe(
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& gate_weight_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
const std::vector<int64_t>& up_gate_proj_weight_shape,
const std::vector<int64_t>& down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape) {
return {input_shape};
}
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& gate_weight_dtype,
const paddle::DataType& ffn1_weight_dtype,
const paddle::DataType& ffn2_weight_dtype,
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
const paddle::DataType& up_gate_proj_weight_dtype,
const paddle::DataType& down_proj_weight_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_bias_dtype,
const paddle::optional<paddle::DataType>& down_proj_scale_dtype) {
return {input_dtype};
}
@@ -230,12 +230,12 @@ std::vector<paddle::DataType> FusedExpertMoeInferDtype(
PD_BUILD_STATIC_OP(fused_expert_moe)
.Inputs({"input",
"gate_weight",
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_bias"),
paddle::Optional("ffn2_scale")})
"up_gate_proj_weight",
"down_proj_weight",
paddle::Optional("up_gate_proj_bias"),
paddle::Optional("up_gate_proj_scale"),
paddle::Optional("down_proj_bias"),
paddle::Optional("down_proj_scale")})
.Outputs({"output"})
.Attrs({"quant_method:std::string",
"moe_topk:int",

View File

@@ -117,18 +117,18 @@ public:
void
ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight,
const paddle::Tensor *ffn1_weight,
const paddle::Tensor *ffn1_scale, const paddle::Tensor *ffn1_bias,
const paddle::Tensor *ffn2_weight,
const paddle::Tensor *ffn2_scale, const paddle::Tensor *ffn2_bias,
const paddle::Tensor *up_gate_proj_weight,
const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias,
const paddle::Tensor *down_proj_weight,
const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias,
const paddle::Tensor *moe_token_type_ids, const int moe_topk,
const bool group_moe, const bool norm_topk_prob,
const float routed_scaling_factor, const std::string moe_type,
paddle::Tensor *output) {
auto *input_activations = input->data<T>();
auto *gating_weights = gate_weight->data<float>();
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
const T *fc2_expert_biases = down_proj_bias ? down_proj_bias->data<T>() : nullptr;
auto *output_ = output->data<T>();
auto stream = input->stream();
@@ -136,7 +136,7 @@ public:
auto input_type = input->dtype();
auto input_dims = input->dims();
auto ffn1_dims = ffn1_weight->dims();
auto up_gate_proj_dims = up_gate_proj_weight->dims();
int64_t token_num = 0;
if (input_dims.size() == 3) {
token_num = input_dims[0] * input_dims[1];
@@ -145,12 +145,12 @@ public:
}
const int64_t num_rows = token_num;
const int64_t hidden_size = ffn1_dims[1];
const int64_t hidden_size = up_gate_proj_dims[1];
int64_t inter_dim = 0;
if (moe_type == "qkv") {
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
} else {
inter_dim = ffn1_dims[2];
inter_dim = up_gate_proj_dims[2];
}
if (gemm_method_ == "weight_only_int4") {
@@ -158,7 +158,7 @@ public:
}
const int64_t inter_size = inter_dim;
const int64_t num_experts = ffn1_dims[0];
const int64_t num_experts = up_gate_proj_dims[0];
const int64_t k = moe_topk;
int64_t bytes =
@@ -260,38 +260,38 @@ public:
total_rows_before_expert_, stream);
if (gemm_method_ == "weight_only_int8") {
typename Int8Traits::Arguments ffn1_quant_args;
typename Int8Traits::Arguments up_gate_proj_quant_args;
int8_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const uint8_t *>(ffn1_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(ffn1_scale->data<T>()),
reinterpret_cast<const uint8_t *>(up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
ffn1_quant_args, "none", stream);
up_gate_proj_quant_args, "none", stream);
} else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments ffn1_quant_args;
typename Int4Traits::Arguments up_gate_proj_quant_args;
int4_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const cutlass::uint4b_t *>(
ffn1_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(ffn1_scale->data<T>()),
up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
ffn1_quant_args, "none", stream);
up_gate_proj_quant_args, "none", stream);
} else {
typename Fp16Traits::Arguments ffn1_quant_args;
typename Fp16Traits::Arguments up_gate_proj_quant_args;
fp16_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const NvType *>(ffn1_weight->data<T>()), nullptr,
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()), nullptr,
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
ffn1_quant_args, "none", stream);
up_gate_proj_quant_args, "none", stream);
}
if (moe_type == "ffn") {
@@ -304,35 +304,35 @@ public:
T *fc2_result = fc2_output_tensor.data<T>();
if (gemm_method_ == "weight_only_int8") {
typename Int8Traits::Arguments ffn2_quant_args;
typename Int8Traits::Arguments down_proj_quant_args;
int8_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const uint8_t *>(ffn2_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(ffn2_scale->data<T>()),
reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, ffn2_quant_args, stream);
num_experts, down_proj_quant_args, stream);
} else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments ffn2_quant_args;
typename Int4Traits::Arguments down_proj_quant_args;
int4_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const cutlass::uint4b_t *>(
ffn2_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(ffn2_scale->data<T>()),
down_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, ffn2_quant_args, stream);
num_experts, down_proj_quant_args, stream);
} else {
typename Fp16Traits::Arguments ffn2_quant_args;
typename Fp16Traits::Arguments down_proj_quant_args;
fp16_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const NvType *>(ffn2_weight->data<T>()), nullptr,
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()), nullptr,
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, ffn2_quant_args, stream);
num_experts, down_proj_quant_args, stream);
}
finalize_moe_routing_kernelLauncher<T>::run(

View File

@@ -24,12 +24,12 @@
template <paddle::DataType T>
void MoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method,
paddle::Tensor ffn_out,
@@ -51,11 +51,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
const int num_experts = ffn1_weight.dims()[0];
const int num_experts = up_gate_proj_weight.dims()[0];
const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
assert(ffn1_weight.dims().size() == 3);
int inter_dim = ffn1_weight.dims()[1] * ffn1_weight.dims()[2] / hidden_size;
assert(up_gate_proj_weight.dims().size() == 3);
int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size;
constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k
Allocator* allocator = paddle::GetAllocator(place);
@@ -96,8 +96,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
using NvType = typename traits_::DataType;
auto fc1_expert_biases =
ffn1_bias
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
up_gate_proj_bias
? const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr())->data<data_t>()
: nullptr;
// This is a trick.
@@ -112,9 +112,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const uint8_t*>(ffn1_weight.data<int8_t>()),
reinterpret_cast<const uint8_t*>(up_gate_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
@@ -132,9 +132,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
int4_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const cutlass::uint4b_t*>(
ffn1_weight.data<int8_t>()),
up_gate_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
@@ -151,12 +151,12 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
w4a8_moe_gemm_runner.moe_gemm(
reinterpret_cast<const int8_t *>(permute_input.data<int8_t>()),
reinterpret_cast<const cutlass::uint4b_t *>(
ffn1_weight.data<int8_t>()),
up_gate_proj_weight.data<int8_t>()),
quant_mode,
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<data_t>()),
nullptr, // ffn1_scale_dyquant
nullptr, // up_gate_proj_scale_dyquant
nullptr, // nf4_look_up_table
reinterpret_cast<NvType *>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -172,7 +172,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const NvType*>(ffn1_weight.data<data_t>()),
reinterpret_cast<const NvType*>(up_gate_proj_weight.data<data_t>()),
nullptr,
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
@@ -199,9 +199,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const uint8_t*>(ffn2_weight.data<int8_t>()),
reinterpret_cast<const uint8_t*>(down_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -218,9 +218,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
int4_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const cutlass::uint4b_t*>(
ffn2_weight.data<int8_t>()),
down_proj_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -232,17 +232,17 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
quant_args,
stream);
} else if (quant_method == "w4a8") {
data_t *ffn2_shift = nullptr;
data_t *ffn2_smooth = nullptr;
data_t *down_proj_shift = nullptr;
data_t *down_proj_smooth = nullptr;
Allocator::AllocationPtr int8_act_out;
int8_act_out = allocator->Allocate(
SizeOf(paddle::DataType::INT8) * act_out_tensor.numel());
MoeFastHardamardWrapper<data_t, int8_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
ffn2_shift, // ffn2_shift->data<T>(),
ffn2_smooth, // ffn2_smooth->data<T>(),
ffn2_in_scale ? const_cast<paddle::Tensor*>(ffn2_in_scale.get_ptr())->data<float>() : nullptr,
down_proj_shift, // down_proj_shift->data<T>(),
down_proj_smooth, // down_proj_smooth->data<T>(),
down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
1,
127.0,
-127.0,
@@ -254,12 +254,12 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
w4a8_moe_gemm_runner.moe_gemm(
reinterpret_cast<int8_t *>(int8_act_out->ptr()),
reinterpret_cast<const cutlass::uint4b_t *>(
ffn2_weight.data<int8_t>()),
down_proj_weight.data<int8_t>()),
quant_mode,
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<data_t>()),
nullptr, // ffn2_scale_dyquant
nullptr, // down_proj_scale_dyquant
nullptr, // reinterpret_cast<const int32_t*>(d_nf4_look_up_table), // nf4_look_up_table
reinterpret_cast<NvType *>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -275,7 +275,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const NvType*>(ffn2_weight.data<data_t>()),
reinterpret_cast<const NvType*>(down_proj_weight.data<data_t>()),
nullptr,
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -292,29 +292,29 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency) {
cudaCheckError();
const auto t_type = quant_method == "w4a8" ? ffn1_scale.get().dtype() : permute_input.dtype();
const auto t_type = quant_method == "w4a8" ? up_gate_proj_scale.get().dtype() : permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);
switch (t_type) {
case paddle::DataType::BFLOAT16:
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn2_in_scale,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method,
ffn_out, used_in_ep_low_latency);
@@ -322,12 +322,12 @@ paddle::Tensor MoeExpertFFNFunc(
case paddle::DataType::FLOAT16:
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn2_in_scale,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method,
ffn_out, used_in_ep_low_latency);
@@ -341,22 +341,22 @@ paddle::Tensor MoeExpertFFNFunc(
std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency) {
return {MoeExpertFFNFunc(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn2_in_scale,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method, used_in_ep_low_latency)};
}
@@ -364,12 +364,12 @@ std::vector<paddle::Tensor> MoeExpertFFN(
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const std::vector<int64_t>& permute_input_shape,
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_in_scale_shape,
const std::vector<int64_t>& up_gate_proj_weight_shape,
const std::vector<int64_t>& down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_in_scale_shape,
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
const std::string& quant_method,
const bool used_in_ep_low_latency) {
@@ -379,15 +379,15 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::DataType &permute_input_dtype,
const paddle::DataType &tokens_expert_prefix_sum_dtype,
const paddle::DataType &ffn1_weight_dtype,
const paddle::DataType &ffn2_weight_dtype,
const paddle::optional<paddle::DataType> &ffn1_bias_dtype,
const paddle::optional<paddle::DataType> &ffn1_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_in_scale_dtype,
const paddle::DataType &up_gate_proj_weight_dtype,
const paddle::DataType &down_proj_weight_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
const std::string &quant_method, const bool used_in_ep_low_latency) {
if (quant_method == "w4a8") {
return {ffn1_scale_dtype.get()};
return {up_gate_proj_scale_dtype.get()};
} else {
return {permute_input_dtype};
}
@@ -397,9 +397,9 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
* @brief Mixture of Experts (MoE) Feed-Forward Network Operator
*
* This operator performs the expert computation in MoE architecture, including:
* 1. First linear transformation (FFN1) with optional quantization
* 1. First linear transformation (up_gate_proj) with optional quantization
* 2. SwiGLU activation function
* 3. Second linear transformation (FFN2) with optional quantization
* 3. Second linear transformation (down_proj) with optional quantization
*
* Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization.
*
@@ -410,22 +410,22 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
* Shape: [num_experts]
* dtype: int64
* - ffn1_weight: First FFN layer weights
* - up_gate_proj_weight: First FFN layer weights
* Shape: [num_experts, inter_size * 2, hidden_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - ffn2_weight: Second FFN layer weights
* - down_proj_weight: Second FFN layer weights
* Shape: [num_experts, hidden_size, inter_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - ffn1_bias: Optional bias for first FFN layer
* - up_gate_proj_bias: Optional bias for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - ffn1_scale: Quantization scales for first FFN layer
* - up_gate_proj_scale: Quantization scales for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - ffn2_scale: Quantization scales for second FFN layer
* - down_proj_scale: Quantization scales for second FFN layer
* Shape: [num_experts, hidden_size]
* dtype: Same as input
* - ffn2_in_scale: Optional input scales for second FFN layer (w4a8 only)
* - down_proj_in_scale: Optional input scales for second FFN layer (w4a8 only)
* dtype: float32
* - expert_idx_per_token: Optional expert indices per token (w4a8 only)
* Shape: [total_tokens]
@@ -434,7 +434,7 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
* Outputs:
* - output_tensor: Output tensor after MoE FFN computation
* Shape: Same as permute_input
* dtype: Same as input (or ffn1_scale dtype for w4a8)
* dtype: Same as input (or up_gate_proj_scale dtype for w4a8)
*
* Attributes:
* - quant_method: Quantization method to use
@@ -449,12 +449,12 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
PD_BUILD_STATIC_OP(moe_expert_ffn)
.Inputs({"permute_input",
"tokens_expert_prefix_sum",
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_scale"),
paddle::Optional("ffn2_in_scale"),
"up_gate_proj_weight",
"down_proj_weight",
paddle::Optional("up_gate_proj_bias"),
paddle::Optional("up_gate_proj_scale"),
paddle::Optional("down_proj_scale"),
paddle::Optional("down_proj_in_scale"),
paddle::Optional("expert_idx_per_token")})
.Outputs({"output_tensor"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool"})

View File

@@ -23,17 +23,17 @@
template <typename DataT, typename NvType, typename WeightSavedT, cutlass::WintQuantMethod QuantMethod>
void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::Tensor* ffn1_bias,
const paddle::Tensor* ffn1_super_scale,
const paddle::Tensor* ffn2_super_scale,
const paddle::Tensor* ffn1_local_scale,
const paddle::Tensor* ffn1_code_scale,
const paddle::Tensor* ffn1_code_zp,
const paddle::Tensor* ffn2_local_scale,
const paddle::Tensor* ffn2_code_scale,
const paddle::Tensor* ffn2_code_zp,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::Tensor* up_gate_proj_bias,
const paddle::Tensor* up_gate_proj_super_scale,
const paddle::Tensor* down_proj_super_scale,
const paddle::Tensor* up_gate_proj_local_scale,
const paddle::Tensor* up_gate_proj_code_scale,
const paddle::Tensor* up_gate_proj_code_zp,
const paddle::Tensor* down_proj_local_scale,
const paddle::Tensor* down_proj_code_scale,
const paddle::Tensor* down_proj_code_zp,
paddle::Tensor fc1_out,
paddle::Tensor ffn_out,
const int64_t total_rows_in_ll_else_minus1,
@@ -46,15 +46,15 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
using WeightOnlyTraits = cutlass::WintQuantTraits<NvType, QuantMethod>;
using WeightType = typename WeightOnlyTraits::WeightType;
typename WeightOnlyTraits::Arguments ffn1_quant_args;
typename WeightOnlyTraits::Arguments ffn2_quant_args;
typename WeightOnlyTraits::Arguments up_gate_proj_quant_args;
typename WeightOnlyTraits::Arguments down_proj_quant_args;
if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
ffn1_quant_args.local_scale_ptr = ffn1_local_scale->data<uint8_t>();
ffn1_quant_args.code_scale_ptr = ffn1_code_scale->data<float>();
ffn1_quant_args.code_zp_ptr = ffn1_code_zp->data<float>();
ffn2_quant_args.local_scale_ptr = ffn2_local_scale->data<uint8_t>();
ffn2_quant_args.code_scale_ptr = ffn2_code_scale->data<float>();
ffn2_quant_args.code_zp_ptr = ffn2_code_zp->data<float>();
up_gate_proj_quant_args.local_scale_ptr = up_gate_proj_local_scale->data<uint8_t>();
up_gate_proj_quant_args.code_scale_ptr = up_gate_proj_code_scale->data<float>();
up_gate_proj_quant_args.code_zp_ptr = up_gate_proj_code_zp->data<float>();
down_proj_quant_args.local_scale_ptr = down_proj_local_scale->data<uint8_t>();
down_proj_quant_args.code_scale_ptr = down_proj_code_scale->data<float>();
down_proj_quant_args.code_zp_ptr = down_proj_code_zp->data<float>();
}
auto moe_gemm_runner = MoeGemmRunner<NvType, WeightOnlyTraits>();
@@ -62,9 +62,9 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<DataT>()),
reinterpret_cast<const WeightType*>(ffn1_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(ffn1_super_scale ? ffn1_super_scale->data<DataT>() : nullptr),
reinterpret_cast<const NvType*>(ffn1_bias ? ffn1_bias->data<DataT>() : nullptr),
reinterpret_cast<const WeightType*>(up_gate_proj_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(up_gate_proj_super_scale ? up_gate_proj_super_scale->data<DataT>() : nullptr),
reinterpret_cast<const NvType*>(up_gate_proj_bias ? up_gate_proj_bias->data<DataT>() : nullptr),
reinterpret_cast<NvType*>(fc1_out.data<DataT>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
@@ -72,7 +72,7 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
inter_size,
hidden_size,
num_experts,
ffn1_quant_args,
up_gate_proj_quant_args,
"none",
stream);
@@ -85,8 +85,8 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out.data<DataT>()),
reinterpret_cast<const WeightType*>(ffn2_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(ffn2_super_scale ? ffn2_super_scale->data<DataT>() : nullptr),
reinterpret_cast<const WeightType*>(down_proj_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(down_proj_super_scale ? down_proj_super_scale->data<DataT>() : nullptr),
reinterpret_cast<NvType*>(ffn_out.data<DataT>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
@@ -94,24 +94,24 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
hidden_size,
inter_size / 2,
num_experts,
ffn2_quant_args,
down_proj_quant_args,
stream);
}
template <paddle::DataType T>
void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
paddle::Tensor ffn_out,
bool used_in_ep_low_latency) {
using namespace phi;
@@ -121,12 +121,12 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
auto place = permute_input.place();
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
assert(ffn1_weight.dims().size() == 3);
assert(up_gate_proj_weight.dims().size() == 3);
const int num_experts = ffn1_weight.dims()[0];
const int num_experts = up_gate_proj_weight.dims()[0];
const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
int inter_dim = ffn1_weight.dims()[1] * ffn1_weight.dims()[2] / hidden_size;
int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size;
const int64_t inter_size = inter_dim * 4;
@@ -160,17 +160,17 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
WeightOnlyMoeFFNKernel<data_t, NvType, uint8_t, cutlass::WintQuantMethod::kWeightOnlyInt2>(
permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
const_cast<paddle::Tensor*>(ffn1_bias.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_code_zp.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_code_zp.get_ptr()),
up_gate_proj_weight,
down_proj_weight,
const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_code_zp.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_code_zp.get_ptr()),
fc1_out_tensor,
ffn_out,
total_rows_in_ll_else_minus1,
@@ -184,17 +184,17 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
const bool used_in_ep_low_latency) {
const auto dtype = permute_input.dtype();
@@ -204,34 +204,34 @@ paddle::Tensor MoeExpertFFNWint2Func(
case paddle::DataType::BFLOAT16:
MoeFFNWint2Kernel<paddle::DataType::BFLOAT16>(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn1_local_scale,
ffn1_code_scale,
ffn1_code_zp,
ffn2_local_scale,
ffn2_code_scale,
ffn2_code_zp,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
ffn_out,
used_in_ep_low_latency);
break;
case paddle::DataType::FLOAT16:
MoeFFNWint2Kernel<paddle::DataType::FLOAT16>(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn1_local_scale,
ffn1_code_scale,
ffn1_code_zp,
ffn2_local_scale,
ffn2_code_scale,
ffn2_code_zp,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
ffn_out,
used_in_ep_low_latency);
break;
@@ -244,49 +244,49 @@ paddle::Tensor MoeExpertFFNWint2Func(
std::vector<paddle::Tensor> MoeExpertFFNWint2(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
const bool used_in_ep_low_latency) {
return {MoeExpertFFNWint2Func(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn1_local_scale,
ffn1_code_scale,
ffn1_code_zp,
ffn2_local_scale,
ffn2_code_scale,
ffn2_code_zp,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
used_in_ep_low_latency)};
}
std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
const std::vector<int64_t>& permute_input_shape,
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_code_zp_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_code_zp_shape,
const std::vector<int64_t>& up_gate_proj_weight_shape,
const std::vector<int64_t>& down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_code_zp_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_code_zp_shape,
const bool used_in_ep_low_latency) {
return {permute_input_shape};
@@ -295,17 +295,17 @@ std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
const paddle::DataType &permute_input_dtype,
const paddle::DataType &tokens_expert_prefix_sum_dtype,
const paddle::DataType &ffn1_weight_dtype,
const paddle::DataType &ffn2_weight_dtype,
const paddle::optional<paddle::DataType> &ffn1_bias_dtype,
const paddle::optional<paddle::DataType> &ffn1_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_scale_dtype,
const paddle::optional<paddle::DataType> &ffn1_local_scale_dtype,
const paddle::optional<paddle::DataType> &ffn1_code_scale_dtype,
const paddle::optional<paddle::DataType> &ffn1_code_zp_dtype,
const paddle::optional<paddle::DataType> &ffn2_local_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_code_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_code_zp_dtype,
const paddle::DataType &up_gate_proj_weight_dtype,
const paddle::DataType &down_proj_weight_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_local_scale_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_code_scale_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_code_zp_dtype,
const paddle::optional<paddle::DataType> &down_proj_local_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_code_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_code_zp_dtype,
const bool used_in_ep_low_latency) {
return {permute_input_dtype};
@@ -315,9 +315,9 @@ std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
* @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network Operator
*
* This operator performs the expert computation in MoE architecture, including:
* 1. First linear transformation (FFN1) with optional quantization
* 1. First linear transformation (up_gate_proj) with optional quantization
* 2. SwiGLU activation function
* 3. Second linear transformation (FFN2) with optional quantization
* 3. Second linear transformation (down_proj) with optional quantization
*
* Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization.
*
@@ -328,26 +328,26 @@ std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
* Shape: [num_experts]
* dtype: int64
* - ffn1_weight: First FFN layer weights
* - up_gate_proj_weight: First FFN layer weights
* Shape: [num_experts, inter_size * 2, hidden_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - ffn2_weight: Second FFN layer weights
* - down_proj_weight: Second FFN layer weights
* Shape: [num_experts, hidden_size, inter_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - ffn1_bias: Optional bias for first FFN layer
* - up_gate_proj_bias: Optional bias for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - ffn1_scale: Quantization scales for first FFN layer
* - up_gate_proj_scale: Quantization scales for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - ffn2_scale: Quantization scales for second FFN layer
* - down_proj_scale: Quantization scales for second FFN layer
* Shape: [num_experts, hidden_size]
* dtype: Same as input
*
* Outputs:
* - output_tensor: Output tensor after MoE FFN computation
* Shape: Same as permute_input
* dtype: Same as input (or ffn1_scale dtype for w4a8)
* dtype: Same as input (or up_gate_proj_scale dtype for w4a8)
*
* Attributes:
* - used_in_ep_low_latency: Whether running in low latency mode
@@ -359,17 +359,17 @@ std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
PD_BUILD_STATIC_OP(moe_expert_ffn_wint2)
.Inputs({"permute_input",
"tokens_expert_prefix_sum",
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_scale"),
paddle::Optional("ffn1_local_scale"),
paddle::Optional("ffn1_code_scale"),
paddle::Optional("ffn1_code_zp"),
paddle::Optional("ffn2_local_scale"),
paddle::Optional("ffn2_code_scale"),
paddle::Optional("ffn2_code_zp")})
"up_gate_proj_weight",
"down_proj_weight",
paddle::Optional("up_gate_proj_bias"),
paddle::Optional("up_gate_proj_scale"),
paddle::Optional("down_proj_scale"),
paddle::Optional("up_gate_proj_local_scale"),
paddle::Optional("up_gate_proj_code_scale"),
paddle::Optional("up_gate_proj_code_zp"),
paddle::Optional("down_proj_local_scale"),
paddle::Optional("down_proj_code_scale"),
paddle::Optional("down_proj_code_zp")})
.Outputs({"output_tensor"})
.Attrs({"used_in_ep_low_latency:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertFFNWint2))

View File

@@ -25,7 +25,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
const paddle::Tensor &top_k_weight,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor, const int num_rows,
const int hidden_size, const int topk,
@@ -38,7 +38,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
finalize_moe_routing_kernelLauncher<data_t>::run(
ffn_out.data<data_t>(), output->data<data_t>(),
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(), num_rows, hidden_size, topk,
static_cast<int>(1), norm_topk_prob, routed_scaling_factor, stream);
@@ -48,7 +48,7 @@ paddle::Tensor MoeExpertReduceFunc(
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const bool norm_topk_prob, const float routed_scaling_factor) {
const auto input_type = ffn_out.dtype();
auto place = ffn_out.place();
@@ -63,13 +63,13 @@ paddle::Tensor MoeExpertReduceFunc(
case paddle::DataType::BFLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(
ffn_out, top_k_weight, permute_indices_per_token, top_k_indices,
ffn2_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
topk, &output);
break;
case paddle::DataType::FLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(
ffn_out, top_k_weight, permute_indices_per_token, top_k_indices,
ffn2_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
topk, &output);
break;
default:
@@ -83,10 +83,10 @@ MoeExpertReduce(const paddle::Tensor &ffn_out,
const paddle::Tensor &top_k_weight,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const bool norm_topk_prob, const float routed_scaling_factor) {
return {MoeExpertReduceFunc(ffn_out, top_k_weight, permute_indices_per_token,
top_k_indices, ffn2_bias, norm_topk_prob,
top_k_indices, down_proj_bias, norm_topk_prob,
routed_scaling_factor)};
}
@@ -95,7 +95,7 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
const std::vector<int64_t> &top_k_weight_shape,
const std::vector<int64_t> &permute_indices_per_token_shape,
const std::vector<int64_t> &top_k_indices_shape,
const paddle::optional<std::vector<int64_t>> &ffn2_bias_shape) {
const paddle::optional<std::vector<int64_t>> &down_proj_bias_shape) {
const int moe_topk = top_k_indices_shape[1];
auto out_shape = ffn_out_shape;
if (out_shape[0] != -1) out_shape[0] /= moe_topk;
@@ -107,7 +107,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
const paddle::DataType &top_k_weight_dtype,
const paddle::DataType &permute_indices_per_token_dtype,
const paddle::DataType &top_k_indices_dtype,
const paddle::optional<paddle::DataType> &ffn2_bias_dtype) {
const paddle::optional<paddle::DataType> &down_proj_bias_dtype) {
return {ffn_out_dtype};
}
@@ -133,7 +133,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
* - top_k_indices: Indices of selected top-k experts for each token
* Shape: [total_tokens, moe_topk]
* dtype: int32
* - ffn2_bias: Optional bias term for expert outputs (hidden_size)
* - down_proj_bias: Optional bias term for expert outputs (hidden_size)
*
* Outputs:
* - output: Combined expert outputs in original token order
@@ -154,7 +154,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
*/
PD_BUILD_STATIC_OP(moe_expert_reduce)
.Inputs({"ffn_out", "top_k_weight", "permute_indices_per_token",
"top_k_indices", paddle::Optional("ffn2_bias")})
"top_k_indices", paddle::Optional("down_proj_bias")})
.Outputs({"output"})
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
.SetKernelFn(PD_KERNEL(MoeExpertReduce))

View File

@@ -25,7 +25,7 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
const paddle::Tensor& top_k_weight,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor,
const int num_rows,
@@ -42,7 +42,7 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
finalize_moe_routing_kernelLauncher(
ffn_out.data<data_t>(),
output->data<data_t>(),
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(),
permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(),
@@ -60,7 +60,7 @@ paddle::Tensor MoeExpertReduceFunc(
const paddle::Tensor& top_k_weight,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor) {
const auto input_type = ffn_out.dtype();
@@ -79,7 +79,7 @@ paddle::Tensor MoeExpertReduceFunc(
top_k_weight,
permute_indices_per_token,
top_k_indices,
ffn2_bias,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
@@ -93,7 +93,7 @@ paddle::Tensor MoeExpertReduceFunc(
top_k_weight,
permute_indices_per_token,
top_k_indices,
ffn2_bias,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
@@ -112,14 +112,14 @@ std::vector<paddle::Tensor> MoeExpertReduce(
const paddle::Tensor& top_k_weight,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor) {
return {MoeExpertReduceFunc(ffn_out,
top_k_weight,
permute_indices_per_token,
top_k_indices,
ffn2_bias,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor)};
}
@@ -129,7 +129,7 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
const std::vector<int64_t>& top_k_weight_shape,
const std::vector<int64_t>& permute_indices_per_token_shape,
const std::vector<int64_t>& top_k_indices_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape) {
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape) {
return {ffn_out_shape};
}
@@ -138,7 +138,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
const paddle::DataType& top_k_weight_dtype,
const paddle::DataType& permute_indices_per_token_dtype,
const paddle::DataType& top_k_indices_dtype,
const paddle::optional<paddle::DataType>& ffn2_bias_dtype) {
const paddle::optional<paddle::DataType>& down_proj_bias_dtype) {
return {ffn_out_dtype};
}
@@ -147,7 +147,7 @@ PD_BUILD_STATIC_OP(moe_expert_reduce)
"top_k_weight",
"permute_indices_per_token",
"top_k_indices",
paddle::Optional("ffn2_bias")})
paddle::Optional("down_proj_bias")})
.Outputs({"output"})
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
.SetKernelFn(PD_KERNEL(MoeExpertReduce))

View File

@@ -46,12 +46,12 @@ template <typename TX, typename TW>
std::vector<paddle::Tensor> MoeLayerKernel(
const paddle::Tensor &x, const paddle::Tensor &gate_weight,
const paddle::optional<paddle::Tensor> &gate_correction_bias,
const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight,
const paddle::optional<paddle::Tensor> &ffn1_bias,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &ffn1_weight_scale,
const paddle::optional<paddle::Tensor> &ffn2_weight_scale,
const paddle::optional<paddle::Tensor> &ffn2_in_scale, // not support
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &up_gate_proj_weight_scale,
const paddle::optional<paddle::Tensor> &down_proj_weight_scale,
const paddle::optional<paddle::Tensor> &down_proj_in_scale, // not support
const std::string &quant_method, const int moe_top_k,
const bool moe_group) {
// std::cout << "[Op Debug] enter moe layer" << std::endl;
@@ -66,24 +66,24 @@ std::vector<paddle::Tensor> MoeLayerKernel(
const auto xtype = x.dtype();
auto x_dims = x.shape();
auto ffn1_dims = ffn1_weight.shape();
auto up_gate_proj_dims = up_gate_proj_weight.shape();
PD_CHECK(x_dims.size() == 2, "x_dims.size() shoud be 2.");
PD_CHECK(ffn1_dims.size() == 3, "ffn1_dims.size() should be 3.");
PD_CHECK(ffn2_in_scale.get_ptr() == nullptr, "ffn2_in_scale not support.");
PD_CHECK(up_gate_proj_dims.size() == 3, "up_gate_proj_dims.size() should be 3.");
PD_CHECK(down_proj_in_scale.get_ptr() == nullptr, "down_proj_in_scale not support.");
if (quant_method == "weight_only_int4") {
PD_CHECK(x_dims[1] == ffn1_dims[2] * 2,
"x_dims[1] should equal to ffn1_dims[2], (weight must be "
PD_CHECK(x_dims[1] == up_gate_proj_dims[2] * 2,
"x_dims[1] should equal to up_gate_proj_dims[2], (weight must be "
"[e,n,k]).");
} else {
PD_CHECK(x_dims[1] == ffn1_dims[2],
"x_dims[1] should equal to ffn1_dims[2], (weight must be "
PD_CHECK(x_dims[1] == up_gate_proj_dims[2],
"x_dims[1] should equal to up_gate_proj_dims[2], (weight must be "
"[e,n,k]).");
}
int token_num = x_dims[0];
int hidden_dim = x_dims[1];
int expert_num = ffn1_dims[0];
int inter_dim = ffn1_dims[1];
int expert_num = up_gate_proj_dims[0];
int inter_dim = up_gate_proj_dims[1];
int outer_dim = inter_dim / 2;
paddle::Tensor fused_moe_out = paddle::empty_like(x);
@@ -118,63 +118,63 @@ std::vector<paddle::Tensor> MoeLayerKernel(
gate_correction_bias.get_ptr()->shape());
}
// ffn1 + ffn2
std::shared_ptr<xftblock::Tensor> xffn1_w, xffn2_w;
// up_gate_proj + down_proj
std::shared_ptr<xftblock::Tensor> xup_gate_proj_w, xdown_proj_w;
if (std::is_same<TW, int4_t>::value) {
xffn1_w = std::make_shared<xftblock::Tensor>(
const_cast<int8_t *>(ffn1_weight.data<int8_t>()), nullptr,
const_cast<float *>(ffn1_weight_scale.get_ptr()
? ffn1_weight_scale.get_ptr()->data<float>()
xup_gate_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<int8_t *>(up_gate_proj_weight.data<int8_t>()), nullptr,
const_cast<float *>(up_gate_proj_weight_scale.get_ptr()
? up_gate_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
xffn2_w = std::make_shared<xftblock::Tensor>(
const_cast<int8_t *>(ffn2_weight.data<int8_t>()), nullptr,
const_cast<float *>(ffn2_weight_scale.get_ptr()
? ffn2_weight_scale.get_ptr()->data<float>()
xdown_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<int8_t *>(down_proj_weight.data<int8_t>()), nullptr,
const_cast<float *>(down_proj_weight_scale.get_ptr()
? down_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
} else {
xffn1_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(ffn1_weight.data<TW>()), nullptr,
const_cast<float *>(ffn1_weight_scale.get_ptr()
? ffn1_weight_scale.get_ptr()->data<float>()
xup_gate_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(up_gate_proj_weight.data<TW>()), nullptr,
const_cast<float *>(up_gate_proj_weight_scale.get_ptr()
? up_gate_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
xffn2_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(ffn2_weight.data<TW>()), nullptr,
const_cast<float *>(ffn2_weight_scale.get_ptr()
? ffn2_weight_scale.get_ptr()->data<float>()
xdown_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(down_proj_weight.data<TW>()), nullptr,
const_cast<float *>(down_proj_weight_scale.get_ptr()
? down_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
}
std::shared_ptr<xftblock::Tensor> xffn1_bias;
std::shared_ptr<xftblock::Tensor> xffn2_bias;
if (ffn1_bias.get_ptr()) {
xffn1_bias = std::make_shared<xftblock::Tensor>(
const_cast<float *>(ffn1_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT, ffn1_bias.get_ptr()->shape());
std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
std::shared_ptr<xftblock::Tensor> xdown_proj_bias;
if (up_gate_proj_bias.get_ptr()) {
xup_gate_proj_bias = std::make_shared<xftblock::Tensor>(
const_cast<float *>(up_gate_proj_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT, up_gate_proj_bias.get_ptr()->shape());
}
if (ffn2_bias.get_ptr()) {
xffn2_bias = std::make_shared<xftblock::Tensor>(
const_cast<float *>(ffn2_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT, ffn2_bias.get_ptr()->shape());
if (down_proj_bias.get_ptr()) {
xdown_proj_bias = std::make_shared<xftblock::Tensor>(
const_cast<float *>(down_proj_bias.get_ptr()->data<float>()),
xftblock::DataType::DT_FLOAT, down_proj_bias.get_ptr()->shape());
}
// std::cout << "[Op Debug] start init moe_ffn weight and bias" <<
// std::endl; MoeFFNWeight
xftblock::MoeFFNWeight moe_ffn_w_struct;
moe_ffn_w_struct.gate_weight = &xgate_w;
moe_ffn_w_struct.ffn_inter_weights = xffn1_w.get();
moe_ffn_w_struct.ffn_inter_bias = xffn1_bias.get();
moe_ffn_w_struct.ffn_outer_weights = xffn2_w.get();
moe_ffn_w_struct.ffn_outer_bias = xffn2_bias.get();
moe_ffn_w_struct.ffn_inter_weights = xup_gate_proj_w.get();
moe_ffn_w_struct.ffn_inter_bias = xup_gate_proj_bias.get();
moe_ffn_w_struct.ffn_outer_weights = xdown_proj_w.get();
moe_ffn_w_struct.ffn_outer_bias = xdown_proj_bias.get();
moe_ffn_w_struct.score_bias = xgate_correct_bias.get();
// MoeFFNParam
xftblock::MoeFFNParam moe_ffn_param;
@@ -198,22 +198,22 @@ std::vector<paddle::Tensor> MoeLayerKernel(
std::vector<paddle::Tensor>
MoeLayer(const paddle::Tensor &x, const paddle::Tensor &gate_weight,
const paddle::optional<paddle::Tensor> &gate_correction_bias,
const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight,
const paddle::optional<paddle::Tensor> &ffn1_bias,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &ffn1_weight_scale,
const paddle::optional<paddle::Tensor> &ffn2_weight_scale,
const paddle::optional<paddle::Tensor> &ffn2_in_scale,
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &up_gate_proj_weight_scale,
const paddle::optional<paddle::Tensor> &down_proj_weight_scale,
const paddle::optional<paddle::Tensor> &down_proj_in_scale,
const std::string &quant_method, const int moe_top_k,
const bool moe_group) {
const auto x_type = x.dtype();
const auto w_type = ffn1_weight.dtype();
const auto w_type = up_gate_proj_weight.dtype();
#define APPLY_MOE_LAYER_KERNEL(TX, TW) \
return MoeLayerKernel<TX, TW>( \
x, gate_weight, gate_correction_bias, ffn1_weight, ffn2_weight, \
ffn1_bias, ffn2_bias, ffn1_weight_scale, ffn2_weight_scale, \
ffn2_in_scale, quant_method, moe_top_k, moe_group);
x, gate_weight, gate_correction_bias, up_gate_proj_weight, down_proj_weight, \
up_gate_proj_bias, down_proj_bias, up_gate_proj_weight_scale, down_proj_weight_scale, \
down_proj_in_scale, quant_method, moe_top_k, moe_group);
// TODO(mayang02): how to use quant_method?
if (x_type == paddle::DataType::BFLOAT16 &&
@@ -237,36 +237,36 @@ std::vector<std::vector<int64_t>> MoeLayerInferShape(
const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &gate_weight_shape,
const paddle::optional<std::vector<int64_t>> &gate_correction_bias_shape,
const std::vector<int64_t> &ffn1_weight_shape,
const std::vector<int64_t> &ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>> &ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>> &ffn2_bias_shape,
const paddle::optional<std::vector<int64_t>> &ffn1_weight_scale_shape,
const paddle::optional<std::vector<int64_t>> &ffn2_weight_scale_shape,
const paddle::optional<std::vector<int64_t>> &ffn2_in_scale_shape) {
const std::vector<int64_t> &up_gate_proj_weight_shape,
const std::vector<int64_t> &down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>> &up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>> &down_proj_bias_shape,
const paddle::optional<std::vector<int64_t>> &up_gate_proj_weight_scale_shape,
const paddle::optional<std::vector<int64_t>> &down_proj_weight_scale_shape,
const paddle::optional<std::vector<int64_t>> &down_proj_in_scale_shape) {
return {x_shape};
}
std::vector<paddle::DataType> MoeLayerInferDtype(
const paddle::DataType &x_dtype, const paddle::DataType &gate_weight_dtype,
const paddle::optional<paddle::DataType> &gate_correction_bias_dtype,
const paddle::DataType &ffn1_weight_dtype,
const paddle::DataType &ffn2_weight_dtype,
const paddle::optional<paddle::DataType> &ffn1_bias_dtype,
const paddle::optional<paddle::DataType> &ffn2_bias_dtype,
const paddle::optional<paddle::DataType> &ffn1_weight_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_weight_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_in_scale_dtype) {
const paddle::DataType &up_gate_proj_weight_dtype,
const paddle::DataType &down_proj_weight_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType> &down_proj_bias_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_weight_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_weight_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype) {
return {x_dtype};
}
PD_BUILD_OP(xpu_moe_layer) // fused_moe
.Inputs({"x", "gate_weight", paddle::Optional("gate_correction_bias"),
"ffn1_weight", "ffn2_weight", paddle::Optional("ffn1_bias"),
paddle::Optional("ffn2_bias"),
paddle::Optional("ffn1_weight_scale"),
paddle::Optional("ffn2_weight_scale"),
paddle::Optional("ffn2_in_scale")})
"up_gate_proj_weight", "down_proj_weight", paddle::Optional("up_gate_proj_bias"),
paddle::Optional("down_proj_bias"),
paddle::Optional("up_gate_proj_weight_scale"),
paddle::Optional("down_proj_weight_scale"),
paddle::Optional("down_proj_in_scale")})
.Outputs({"fused_moe_out"})
.Attrs({"quant_method:std::string", "moe_top_k:int", "moe_group:bool"})
.SetKernelFn(PD_KERNEL(MoeLayer))

View File

@@ -19,12 +19,10 @@ from paddle import nn
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
get_tensor)
from fastdeploy.model_executor.layers.quantization.quant_base import \
QuantMethodBase
from fastdeploy.utils import ceil_div
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
@@ -36,9 +34,9 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
Triton Group Gemm to compute Fused MoE.
"""
self.quant_method = quant_method
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
@@ -49,26 +47,26 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert self.quant_method.name() == "wint8"
assert ffn1_weights[0].shape == [
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
if self.quant_method.name() == "wint8":
max_bound = 127
elif self.quant_method.name() == "wint4":
max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -150,10 +148,10 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
x,
layer.moe_ffn1_weight,
layer.up_gate_proj_weight,
intermediate_cache1,
None,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
@@ -164,17 +162,17 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
token_num * top_k,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_bn=layer.moe_ffn1_weight.strides[2],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=layer.up_gate_proj_weight.strides[2],
stride_cm=intermediate_cache1.strides[0],
stride_cn=intermediate_cache1.strides[1],
#
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
@@ -197,10 +195,10 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
intermediate_cache2,
layer.moe_ffn2_weight,
layer.down_proj_weight,
intermediate_cache3,
None,
layer.moe_ffn2_weight_scale,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -211,16 +209,16 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
token_num * top_k,
stride_am=intermediate_cache2.strides[0],
stride_ak=intermediate_cache2.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_bn=layer.moe_ffn2_weight.strides[2],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=layer.down_proj_weight.strides[2],
stride_cm=intermediate_cache3.strides[0],
stride_cn=intermediate_cache3.strides[1],
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
stride_bsn=layer.down_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters

View File

@@ -16,8 +16,8 @@
import paddle
from paddle.nn.quant import weight_dequantize
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig, GPUWeightOnlyLinearMethod
from fastdeploy.model_executor.layers.quantization.weight_only import (
GPUWeightOnlyLinearMethod, WeightOnlyConfig)
class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod):
@@ -35,12 +35,12 @@ class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod):
def apply(self, layer, x):
dequant_out = weight_dequantize(
x=layer.linear_weight,
scale=layer.linear_weight_scale,
x=layer.weight,
scale=layer.weight_scale,
algo=self.quant_config.algo,
out_dtype=paddle.get_default_dtype()
)
linear_out = paddle.matmul(x, dequant_out)
if layer.linear_bias is not None:
linear_out = paddle.add(linear_out, layer.linear_bias)
if layer.bias is not None:
linear_out = paddle.add(linear_out, layer.bias)
return linear_out

View File

@@ -50,11 +50,11 @@ class GCUFusedMoeMethod(MoEMethodBase):
Paddle gcu create weight process.
"""
# bf16
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
for idx, weight_tensor in enumerate(
[stacked_ffn1_weights, stacked_ffn2_weights]):
[stacked_up_gate_proj_weights, stacked_down_proj_weights]):
# shape [E, K, N] -> [E, N, K]
weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
weight_name = self.added_weight_attrs[idx]
@@ -117,16 +117,16 @@ class GCUFusedMoeMethod(MoEMethodBase):
dtype=x.dtype,
)
ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None
ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None
up_gate_proj_B_scale = layer.up_gate_proj_weight_scale if enable_quant else None
up_gate_proj_B_zeros = layer.up_gate_proj_weight_zeros if enable_quant else None
invoke_fused_moe_kernel(
x, # input
layer.moe_ffn1_weight, # weight
layer.up_gate_proj_weight, # weight
intermediate_cache1, # output
None, # A_scale
ffn1_B_scale, # B_scale
ffn1_B_zeros, # B_zp
up_gate_proj_B_scale, # B_scale
up_gate_proj_B_zeros, # B_zp
topk_weights,
topk_indices,
sorted_token_ids,
@@ -154,16 +154,16 @@ class GCUFusedMoeMethod(MoEMethodBase):
dtype=x.dtype,
)
ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None
ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None
down_proj_B_scale = layer.down_proj_weight_scale if enable_quant else None
down_proj_B_zeros = layer.down_proj_weight_zeros if enable_quant else None
invoke_fused_moe_kernel(
intermediate_cache2, # input
layer.moe_ffn2_weight, # weight
layer.down_proj_weight, # weight
intermediate_cache3, # output
None, # A_scale
ffn2_B_scale, # B_scale
ffn2_B_zeros, # B_zp
down_proj_B_scale, # B_scale
down_proj_B_zeros, # B_zp
topk_weights,
topk_indices,
sorted_token_ids,
@@ -251,7 +251,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"GCUWeightOnlyMoEMethod only support weight_only_int4, but got:{self.quant_config.algo}"
self.added_qzeros_attrs = [
"moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros"
"up_gate_proj_weight_zeros", "down_proj_weight_zeros"
]
self.group_size = 64
@@ -265,41 +265,41 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"""
Paddle gcu process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
"down_proj_expert_weight_scale_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for i in range(layer.num_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
up_gate_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
down_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
down_proj_expert_weight_scale_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@@ -310,8 +310,8 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
@@ -329,7 +329,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
zeros_name = self.added_qzeros_attrs[idx]
@@ -365,8 +365,8 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
dict_ = dict(shared_dict)
for k, v in dict_.items():
weight_list[k] = v[0].to(ffn1_weights[0].place)
weight_scale_list[k] = v[1].to(ffn1_weights[0].place)
weight_list[k] = v[0].to(up_gate_proj_weights[0].place)
weight_scale_list[k] = v[1].to(up_gate_proj_weights[0].place)
else:
remain_weights_start_idx = 0

View File

@@ -38,14 +38,14 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
def create_weights(self, layer):
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
weight_scale_shape = [layer.weight_shape[1]]
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype=layer._dtype,
is_bias=False,
)
@@ -61,8 +61,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
layer.linear_weight.set_value(quant_weight)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quant_weight)
layer.weight_scale.set_value(
weight_scale.astype(paddle.get_default_dtype()))
@@ -73,8 +73,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
self.group_size, # group_size
)
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(
weight_scale_tensor.astype(paddle.get_default_dtype()))
@@ -82,8 +82,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
def apply(self, layer, x):
linear_out = linear_quant(
lhs=x,
rhs=layer.linear_weight,
scale=layer.linear_weight_scale,
rhs=layer.weight,
scale=layer.weight_scale,
bias=None,
group_size=self.group_size,
)

View File

@@ -37,13 +37,13 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
Create weights for linear layer on XPU
"""
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
layer.linear_weight_shape.reverse()
weight_scale_shape = [layer.weight_shape[1]]
layer.weight_shape.reverse()
if self.quant_config.name() == "weight_only_int4":
layer.linear_weight_shape[0] //= 2
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype="float32",
is_bias=False,
)
@@ -55,6 +55,6 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(
weight, self.quant_config.algo, -1, -1)
layer.linear_weight.set_value(
layer.weight.set_value(
paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.linear_weight_scale.set_value(weight_scale_tensor)
layer.weight_scale.set_value(weight_scale_tensor)

View File

@@ -68,13 +68,13 @@ class VocabParallelEmbedding(nn.Layer):
self.params_dtype: str = params_dtype
if self.use_ep:
self.word_embeddings = nn.Embedding(
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim,
)
else:
if not self.column_cut:
self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding(
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
num_embeddings,
embedding_dim,
mp_group=fleet.get_hybrid_communicate_group().
@@ -85,13 +85,13 @@ class VocabParallelEmbedding(nn.Layer):
)
else:
# column cut embedding
self.word_embeddings = nn.Embedding(
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim // self.world_size,
)
self.word_embeddings.weight.is_distributed = True
self.word_embeddings.weight.split_axis = 1
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
if not self.use_rope:
self.position_embeddings = nn.Embedding(
@@ -112,13 +112,12 @@ class VocabParallelEmbedding(nn.Layer):
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
a = state_dict[self.prefix + ".weight"]
if self.tie_word_embeddings:
self.word_embeddings.weight.set_value(
self.embeddings.weight.set_value(
get_tensor(state_dict[self.prefix + ".weight"]).astype(
paddle.get_default_dtype()))
else:
self.word_embeddings.weight.set_value(
self.embeddings.weight.set_value(
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
paddle.get_default_dtype()))
@@ -134,10 +133,10 @@ class VocabParallelEmbedding(nn.Layer):
Tensor: Embedded tensor representation of the input IDs.
"""
if self.use_ep:
input_embedings = self.word_embeddings(ids_remove_padding)
input_embedings = self.embeddings(ids_remove_padding)
else:
if self.column_cut:
input_embedings = self.word_embeddings(ids_remove_padding)
input_embedings = self.embeddings(ids_remove_padding)
inputs_embeds_temp = []
paddle.distributed.all_gather(
inputs_embeds_temp,
@@ -148,6 +147,6 @@ class VocabParallelEmbedding(nn.Layer):
)
input_embedings = paddle.concat(inputs_embeds_temp, -1)
else:
input_embedings = self.word_embeddings(ids_remove_padding)
input_embedings = self.embeddings(ids_remove_padding)
return input_embedings

View File

@@ -14,16 +14,13 @@
# limitations under the License.
"""
from paddleformers.utils.log import logger
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import (
ColumnParallelLinear,
VocabParallelEmbedding,
)
from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear,
VocabParallelEmbedding)
from paddleformers.utils.log import logger
from .utils import get_tensor
@@ -130,7 +127,7 @@ class HydraHead(nn.Layer):
]
)
self.word_embeddings = VocabParallelEmbedding(
self.embeddings = VocabParallelEmbedding(
vocab_size,
hidden_size,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
@@ -170,8 +167,8 @@ class HydraHead(nn.Layer):
get_tensor(state_dict.pop(f"1.{hydra_head_idx}.weight"))
)
self.word_embeddings.weight.set_value(
get_tensor(state_dict.pop("word_embeddings.weight"))
self.embeddings.weight.set_value(
get_tensor(state_dict.pop("embeddings.weight"))
)
def set_state_dict(self, state_dict):
@@ -183,7 +180,7 @@ class HydraHead(nn.Layer):
"""
is_custom = True
for key in state_dict.keys():
if key != "word_embeddings.weight" and (
if key != "embeddings.weight" and (
"hydra_mlp" in key or "hydra_head" in key
):
is_custom = False
@@ -207,7 +204,7 @@ class HydraHead(nn.Layer):
hidden_states: [batch_size, hidden_size] The hidden_states of the last accept_tokens
"""
hydra_inputs = [hidden_states]
input_embeds = self.word_embeddings(input_ids)
input_embeds = self.embeddings(input_ids)
for hydra_head_idx in range(self.hydra_num_heads):
hydra_inputs.append(input_embeds)
head_input = paddle.concat(hydra_inputs, axis=-1)
@@ -217,4 +214,4 @@ class HydraHead(nn.Layer):
_, topk_tokens = paddle.topk(probs, k=1, axis=-1)
next_tokens[:, 1 + hydra_head_idx : 2 + hydra_head_idx] = topk_tokens[:]
input_embeds = self.word_embeddings(next_tokens[:, 1 + hydra_head_idx])
input_embeds = self.embeddings(next_tokens[:, 1 + hydra_head_idx])

View File

@@ -79,7 +79,7 @@ class LinearBase(nn.Layer):
self._dtype = self._helper.get_default_dtype()
self.weight_dtype = self._dtype
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -96,16 +96,16 @@ class LinearBase(nn.Layer):
"""
if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
self.weight = self.create_parameter(
shape=self.weight_shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.linear_bias = None
self.bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
@@ -136,7 +136,7 @@ class LinearBase(nn.Layer):
if self.fd_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.linear_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)
def load_state_dict(self, state_dict: dict):
"""
@@ -157,7 +157,7 @@ class LinearBase(nn.Layer):
if self.with_bias:
bias_tensor = paddle.to_tensor(
get_tensor(state_dict.pop(self.bias_key)))
self.linear_bias.set_value(bias_tensor)
self.bias.set_value(bias_tensor)
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
"""
@@ -175,9 +175,9 @@ class LinearBase(nn.Layer):
if self.fd_config.quant_config:
linear_out = self.quant_method.apply(self, x)
else:
linear_out = paddle.matmul(x, self.linear_weight)
linear_out = paddle.matmul(x, self.weight)
if self.with_bias:
linear_out = paddle.add(linear_out, self.linear_bias)
linear_out = paddle.add(linear_out, self.bias)
return linear_out
@@ -219,7 +219,7 @@ class ReplicatedLinear(LinearBase):
skip_quant=skip_quant)
self.hidden_size = fd_config.model_config.hidden_size
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -272,7 +272,7 @@ class ColumnParallelLinear(LinearBase):
output_size,
self.nranks) # Split the output_size using TP inference.
self.hidden_size = fd_config.model_config.hidden_size
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -286,26 +286,26 @@ class ColumnParallelLinear(LinearBase):
"""
if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
self.weight = self.create_parameter(
shape=self.weight_shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
if self.nranks > 0:
# col parallel
_set_var_distributed(self.linear_weight, split_axis=1)
_set_var_distributed(self.weight, split_axis=1)
self.linear_bias = None
self.bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
)
if self.nranks > 0:
# col parallel
_set_var_distributed(self.linear_bias, split_axis=1)
_set_var_distributed(self.bias, split_axis=1)
# smooth quant
self.linear_shift = None
@@ -333,7 +333,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_quant: bool = False,
):
"""
Initialize the fused ffn1 Linear layer with given parameters.
Initialize the fused up_gate_proj Linear layer with given parameters.
Args:
fd_config (FDConfig): Inference-related parameters.
@@ -462,7 +462,7 @@ class QKVParallelLinear(ColumnParallelLinear):
if self.fd_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.linear_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)
def load_state_dict(self, state_dict: dict):
"""
@@ -485,7 +485,7 @@ class QKVParallelLinear(ColumnParallelLinear):
if self.bias_key in state_dict.keys():
bias_tensor = paddle.to_tensor(
get_tensor(state_dict.pop(self.bias_key)))
self.linear_bias.set_value(bias_tensor)
self.bias.set_value(bias_tensor)
else:
q_bias_key = self.bias_key.replace("qkv_proj", "q_proj")
k_bias_key = self.bias_key.replace("qkv_proj", "k_proj")
@@ -494,7 +494,7 @@ class QKVParallelLinear(ColumnParallelLinear):
k_bias = get_tensor(state_dict.pop(k_bias_key))
v_bias = get_tensor(state_dict.pop(v_bias_key))
qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
self.linear_bias.set_value(qkv_bias)
self.bias.set_value(qkv_bias)
class RowParallelLinear(LinearBase):
@@ -554,7 +554,7 @@ class RowParallelLinear(LinearBase):
self.input_size = divide(input_size, self.nranks)
self.output_size = output_size
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -574,16 +574,16 @@ class RowParallelLinear(LinearBase):
if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
self.weight = self.create_parameter(
shape=self.weight_shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.linear_bias = None
self.bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.hidden_size],
dtype=self._dtype,
is_bias=True,
@@ -591,7 +591,7 @@ class RowParallelLinear(LinearBase):
if self.nranks > 0:
# row parallel
_set_var_distributed(self.linear_weight, split_axis=0)
_set_var_distributed(self.weight, split_axis=0)
# smooth quant
self.linear_shift = None
@@ -601,7 +601,7 @@ class RowParallelLinear(LinearBase):
if self.fd_config.quant_config:
out = self.quant_method.apply(self, x)
else:
out = paddle.matmul(x, self.linear_weight)
out = paddle.matmul(x, self.weight)
if self.reduce_results and self.nranks > 1:
tensor_model_parallel_all_reduce(out)

View File

@@ -52,11 +52,11 @@ class ParallelLMHead(nn.Layer):
with_bias (bool): whether to have bias. Default: False.
"""
super(ParallelLMHead, self).__init__()
self.linear_weight_key: str = prefix + ".weight"
self.weight_key: str = prefix + ".weight"
if with_bias:
self.linear_bias_key: Optional[str] = prefix + ".bias"
self.bias_key: Optional[str] = prefix + ".bias"
else:
self.linear_bias_key: Optional[str] = None
self.bias_key: Optional[str] = None
self.use_ep: bool = fd_config.parallel_config.use_ep
self.column_cut = True
@@ -74,26 +74,26 @@ class ParallelLMHead(nn.Layer):
else:
if self.column_cut:
need_gather = True
self.out_linear = ColumnParallelLinear(
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False, # False diff更小
)
else:
self.out_linear = RowParallelLinear(
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False, # False diff更小
)
@@ -109,25 +109,25 @@ class ParallelLMHead(nn.Layer):
if self.use_ep:
self.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
get_tensor(state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype()))
else:
if self.tie_word_embeddings:
self.out_linear.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
self.linear.weight.set_value(
get_tensor(state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype()).transpose([1, 0]))
else:
weight_tensor = get_tensor(
state_dict.pop(self.linear_weight_key)).astype(
state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype())
if self.out_linear.weight.shape != weight_tensor.shape:
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.out_linear.weight.set_value(weight_tensor)
self.linear.weight.set_value(weight_tensor)
if self.linear_bias_key is not None:
bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(
paddle.get_default_dtype())
self.out_linear.bias.set_value(bias)
self.linear.bias.set_value(bias)
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
"""
@@ -143,5 +143,5 @@ class ParallelLMHead(nn.Layer):
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
logits = self.out_linear(logits)
logits = self.linear(logits)
return logits

View File

@@ -34,9 +34,9 @@ class MoEMethodBase(QuantMethodBase):
self.moe_quant_type = "w16a16"
else:
self.quant_config = quant_config
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
self.pack_num = 1
@@ -63,14 +63,14 @@ class MoEMethodBase(QuantMethodBase):
"""
pass
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
"""
check layer is valid for this method
"""
assert ffn1_weights[0].shape == [
assert up_gate_proj_weights[0].shape == [
layer.hidden_size // self.pack_num, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size // self.pack_num, layer.hidden_size
]

View File

@@ -31,7 +31,8 @@ if current_platform.is_cuda() and not current_platform.is_dcu():
from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
moe_expert_reduce, noaux_tc)
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import moe_expert_dispatch, moe_expert_reduce
from fastdeploy.model_executor.ops.iluvatar import (moe_expert_dispatch,
moe_expert_reduce)
# used for deepseek_v3
@@ -65,11 +66,11 @@ class CutlassMoEMethod(MoEMethodBase):
Paddle cutlass create weight process.
"""
# bf16
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
for idx, weight_tensor in enumerate(
[stacked_ffn1_weights, stacked_ffn2_weights]):
[stacked_up_gate_proj_weights, stacked_down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
setattr(
layer, weight_name,
@@ -95,15 +96,15 @@ class CutlassMoEMethod(MoEMethodBase):
return fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn(
permute_input,
token_nums_per_expert,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None,
(layer.moe_ffn1_weight_scale if hasattr(
layer, "moe_ffn1_weight_scale") else None),
(layer.moe_ffn2_weight_scale if hasattr(
layer, "moe_ffn2_weight_scale") else None),
(layer.moe_ffn2_in_scale
if hasattr(layer, "moe_ffn2_in_scale") else None),
(layer.up_gate_proj_weight_scale if hasattr(
layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(
layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale
if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
self.moe_quant_type,
used_in_ep_low_latency,
@@ -111,15 +112,15 @@ class CutlassMoEMethod(MoEMethodBase):
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
permute_input,
token_nums_per_expert,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None,
(layer.moe_ffn1_weight_scale
if hasattr(layer, "moe_ffn1_weight_scale") else None),
(layer.moe_ffn2_weight_scale
if hasattr(layer, "moe_ffn2_weight_scale") else None),
(layer.moe_ffn2_in_scale
if hasattr(layer, "moe_ffn2_in_scale") else None),
(layer.up_gate_proj_weight_scale
if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale
if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale
if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
self.moe_quant_type,
used_in_ep_low_latency,
@@ -163,8 +164,8 @@ class CutlassMoEMethod(MoEMethodBase):
recv_x,
recv_topk_idx,
recv_topk_weights,
(self.moe_ffn1_in_scale
if hasattr(self, "moe_ffn1_in_scale") else None),
(self.up_gate_proj_in_scale
if hasattr(self, "up_gate_proj_in_scale") else None),
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
@@ -186,7 +187,7 @@ class CutlassMoEMethod(MoEMethodBase):
dst_weights,
permute_indices_per_token,
dst_indices,
None, # moe_ffn2_bias,
None, # down_proj_bias,
False, # norm_topk_prob
1.0,
)[0]
@@ -256,7 +257,7 @@ class CutlassMoEMethod(MoEMethodBase):
x,
gate_out,
None, # Use layer.gate_correction_bias in get_moe_scores.
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
@@ -274,7 +275,7 @@ class CutlassMoEMethod(MoEMethodBase):
x,
gate_out,
layer.gate_correction_bias,
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
@@ -323,9 +324,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
weight_list = []
for i in range(layer.num_local_experts):
@@ -366,26 +367,26 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
create_and_set_parameter(layer, name, processed_weight_scale)
# 1. Init scale containers and maps
moe_ffn1_weight_scales = []
moe_ffn2_weight_scales = []
moe_ffn1_in_scales = []
moe_ffn2_in_scales = []
up_gate_proj_weight_scales = []
down_proj_weight_scales = []
up_gate_proj_in_scales = []
down_proj_in_scales = []
scale_weight_map = {
"moe_ffn1_weight_scale": moe_ffn1_weight_scales,
"moe_ffn2_weight_scale": moe_ffn2_weight_scales,
"moe_ffn1_in_scale": moe_ffn1_in_scales,
"moe_ffn2_in_scale": moe_ffn2_in_scales,
"up_gate_proj_weight_scale": up_gate_proj_weight_scales,
"down_proj_weight_scale": down_proj_weight_scales,
"up_gate_proj_in_scale": up_gate_proj_in_scales,
"down_proj_in_scale": down_proj_in_scales,
}
scale_key_map = {
"moe_ffn1_weight_scale":
weight_key_map.get("ffn1_expert_weight_scale_key", None),
"moe_ffn2_weight_scale":
weight_key_map.get("ffn2_expert_weight_scale_key", None),
"moe_ffn1_in_scale":
weight_key_map.get("ffn1_expert_in_scale_key", None),
"moe_ffn2_in_scale":
weight_key_map.get("ffn2_expert_in_scale_key", None),
"up_gate_proj_weight_scale":
weight_key_map.get("up_gate_proj_expert_weight_scale_key", None),
"down_proj_weight_scale":
weight_key_map.get("down_proj_expert_weight_scale_key", None),
"up_gate_proj_in_scale":
weight_key_map.get("up_gate_proj_expert_in_scale_key", None),
"down_proj_in_scale":
weight_key_map.get("down_proj_expert_in_scale_key", None),
}
for name, value in scale_key_map.items():
if value is None:
@@ -404,13 +405,13 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
# 3. Process scale tensor and set to layer
in_scales = []
for in_scale_name in ["moe_ffn1_in_scale", "moe_ffn2_in_scale"]:
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
in_scales.append(
_process_in_scale(in_scale_name,
scale_weight_map[in_scale_name]))
for i, weight_scale_name in enumerate(
["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]):
["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
_process_weight_scale(weight_scale_name,
scale_weight_map[weight_scale_name],
in_scales[i])
@@ -431,41 +432,41 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
"""
Paddle cutlass process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
"down_proj_expert_weight_scale_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for i in range(layer.num_local_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
up_gate_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
down_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
down_proj_expert_weight_scale_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@@ -474,10 +475,10 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]

View File

@@ -39,11 +39,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
deepgemm create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -70,41 +70,41 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
"""
Paddle cutlass process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
"down_proj_expert_weight_scale_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for i in range(layer.num_local_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
up_gate_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
down_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
down_proj_expert_weight_scale_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
ffn2_weight = paddle.stack(ffn2_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
down_proj_weight = paddle.stack(down_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@@ -171,21 +171,21 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
permute_scale = permute_scale.transpose([1, 0]).contiguous()
permute_scale = permute_scale.transpose([1, 0])
# ffn1
# up_gate_proj
ffn_out = paddle.empty(
(permute_input.shape[0], layer.moe_ffn1_weight.shape[1]),
(permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
dtype=paddle.bfloat16,
)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(permute_input, permute_scale),
(layer.moe_ffn1_weight, layer.moe_ffn1_weight_scale),
(layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
ffn_out,
m_indices,
)
# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
# ffn2
# down_proj
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0])
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose(
@@ -193,11 +193,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
ffn_out = paddle.empty(
(ffn_out.shape[0], layer.moe_ffn2_weight.shape[1]),
(ffn_out.shape[0], layer.down_proj_weight.shape[1]),
dtype=paddle.bfloat16)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(ffn_in_x, ffn_in_x_scale_tensor),
(layer.moe_ffn2_weight, layer.moe_ffn2_weight_scale),
(layer.down_proj_weight, layer.down_proj_weight_scale),
ffn_out,
m_indices,
)
@@ -207,7 +207,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
dst_weights,
permute_indices_per_token,
dst_indices,
None, # moe_ffn2_bias
None, # down_proj_bias
False, # norm_topk_prob
1.0,
)[0]
@@ -237,7 +237,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
# 3. Compute ffn
assert isinstance(permute_input, tuple)
ffn1_out = paddle.empty(
up_gate_proj_out = paddle.empty(
[
layer.num_local_experts,
layer.ep_size *
@@ -261,16 +261,16 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
permute_input,
(
layer.moe_ffn1_weight,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight,
layer.up_gate_proj_weight_scale,
),
ffn1_out,
up_gate_proj_out,
token_nums_per_expert,
expected_m,
)
act_out = fastdeploy.model_executor.ops.gpu.group_swiglu_with_masked(
ffn1_out, token_nums_per_expert)
up_gate_proj_out, token_nums_per_expert)
act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.masked_per_token_quant(
act_out, token_nums_per_expert,
@@ -279,8 +279,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
(act_out_fp8, scale),
(
layer.moe_ffn2_weight,
layer.moe_ffn2_weight_scale,
layer.down_proj_weight,
layer.down_proj_weight_scale,
),
ffn_out,
token_nums_per_expert,
@@ -339,21 +339,21 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
permute_scale = permute_scale.transpose([1, 0]).contiguous()
permute_scale = permute_scale.transpose([1, 0])
# ffn1
# up_gate_proj
ffn_out = paddle.empty(
(permute_input.shape[0], layer.moe_ffn1_weight.shape[1]),
(permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
dtype=paddle.bfloat16,
)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(permute_input, permute_scale),
(layer.moe_ffn1_weight, layer.moe_ffn1_weight_scale),
(layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
ffn_out,
m_indices,
)
# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)
# ffn2
# down_proj
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0])
@@ -362,11 +362,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
ffn_out = paddle.empty(
(ffn_out.shape[0], layer.moe_ffn2_weight.shape[1]),
(ffn_out.shape[0], layer.down_proj_weight.shape[1]),
dtype=paddle.bfloat16)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(ffn_in_x, ffn_in_x_scale_tensor),
(layer.moe_ffn2_weight, layer.moe_ffn2_weight_scale),
(layer.down_proj_weight, layer.down_proj_weight_scale),
ffn_out,
m_indices,
)

View File

@@ -103,9 +103,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
Marlin Group Gemm to compute Fused MoE.
"""
self.quant_method = quant_method
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
self.added_zeros_attrs = ["zeros0", "zeros1"]
@@ -113,22 +113,22 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
"""
Marlin MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
assert ffn1_weights[0].shape == [
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -221,8 +221,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
ffn_out = MoeWna16MarlinGemmApi(
x,
c_or_none=None,
b_q_weight=layer.moe_ffn1_weight,
b_scales=layer.moe_ffn1_weight_scale,
b_q_weight=layer.up_gate_proj_weight,
b_scales=layer.up_gate_proj_weight_scale,
global_scale_or_none=None,
b_zeros_or_none=None,
g_idx_or_none=None,
@@ -250,8 +250,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
ffn_out = MoeWna16MarlinGemmApi(
swiglu_out,
c_or_none=None,
b_q_weight=layer.moe_ffn2_weight,
b_scales=layer.moe_ffn2_weight_scale,
b_q_weight=layer.down_proj_weight,
b_scales=layer.down_proj_weight_scale,
global_scale_or_none=None,
b_zeros_or_none=None,
g_idx_or_none=None,

View File

@@ -30,7 +30,7 @@ try:
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
from .triton_moe_kernels import fused_moe_kernel_paddle
except:
except ImportError:
pass
@@ -44,9 +44,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
Triton Group Gemm to compute Fused MoE.
"""
self.quant_config = quant_config
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
@@ -57,30 +57,30 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
assert ffn1_weights[0].shape == [
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
if algo == "wint8":
max_bound = 127
elif algo == "wint4":
max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -130,7 +130,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
True, # apply_norm_weight,
False,
)
ffn1_out = paddle.empty(
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
@@ -150,10 +150,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
x,
layer.moe_ffn1_weight,
ffn1_out,
layer.up_gate_proj_weight,
up_gate_proj_out,
None,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
@@ -164,17 +164,17 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
K=hidden_size,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_bn=layer.moe_ffn1_weight.strides[2],
stride_cm=ffn1_out.strides[0],
stride_cn=ffn1_out.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=layer.up_gate_proj_weight.strides[2],
stride_cm=up_gate_proj_out.strides[0],
stride_cn=up_gate_proj_out.strides[1],
#
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
@@ -190,10 +190,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
)
ffn2_input = paddle.incubate.nn.functional.swiglu(
ffn1_out)
down_proj_input = paddle.incubate.nn.functional.swiglu(
up_gate_proj_out)
ffn2_out = paddle.empty(
down_proj_out = paddle.empty(
(token_num * top_k, hidden_size),
dtype=x.dtype,
)
@@ -202,11 +202,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
ffn2_input,
layer.moe_ffn2_weight,
ffn2_out,
down_proj_input,
layer.down_proj_weight,
down_proj_out,
None,
layer.moe_ffn2_weight_scale,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -215,18 +215,18 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=ffn2_input.strides[0],
stride_ak=ffn2_input.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_bn=layer.moe_ffn2_weight.strides[2],
stride_cm=ffn2_out.strides[0],
stride_cn=ffn2_out.strides[1],
stride_am=down_proj_input.strides[0],
stride_ak=down_proj_input.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=layer.down_proj_weight.strides[2],
stride_cm=down_proj_out.strides[0],
stride_cn=down_proj_out.strides[1],
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
stride_bsn=layer.down_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
@@ -242,8 +242,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
)
ffn2_out.reshape_([token_num, top_k, hidden_size])
out = ffn2_out.sum(axis=1)
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
return out
@@ -261,20 +261,20 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
"""process_prequanted_weights"""
ffn1_tensor, ffn2_tensor = layer.extract_moe_ffn_weights(state_dict)
assert ffn1_tensor[0].shape == [
up_gate_proj_tensor, down_proj_tensor = layer.extract_moe_ffn_weights(state_dict)
assert up_gate_proj_tensor[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_tensor[0].shape == [
assert down_proj_tensor[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_tensor, axis=0).view(paddle.float8_e4m3fn)
ffn2_tensor = paddle.stack(ffn2_tensor, axis=0).view(paddle.float8_e4m3fn)
up_gate_proj_tensor = paddle.stack(up_gate_proj_tensor, axis=0).view(paddle.float8_e4m3fn)
down_proj_tensor = paddle.stack(down_proj_tensor, axis=0).view(paddle.float8_e4m3fn)
added_wfp8afp8_attrs = [
"moe_ffn1_weight", "moe_ffn2_weight", "moe_ffn1_weight_scale",
"moe_ffn2_weight_scale", "moe_ffn1_in_scale", "moe_ffn2_in_scale"
"up_gate_proj_weight", "down_proj_weight", "up_gate_proj_weight_scale",
"down_proj_weight_scale", "up_gate_proj_in_scale", "down_proj_in_scale"
]
def _extract_scale_tensor(key_template):
@@ -285,18 +285,18 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
return paddle.concat(result).cast("float32")
weight_key_map = layer.weight_key_map
moe_ffn1_weight_scale = _extract_scale_tensor(
weight_key_map["ffn1_expert_weight_scale_key"])
moe_ffn2_weight_scale = _extract_scale_tensor(
weight_key_map["ffn2_expert_weight_scale_key"])
moe_ffn1_in_scale = _extract_scale_tensor(
weight_key_map["ffn1_expert_in_scale_key"])
moe_ffn2_in_scale = _extract_scale_tensor(
weight_key_map["ffn2_expert_in_scale_key"])
up_gate_proj_weight_scale = _extract_scale_tensor(
weight_key_map["up_gate_proj_expert_weight_scale_key"])
down_proj_weight_scale = _extract_scale_tensor(
weight_key_map["down_proj_expert_weight_scale_key"])
up_gate_proj_in_scale = _extract_scale_tensor(
weight_key_map["up_gate_proj_expert_in_scale_key"])
down_proj_in_scale = _extract_scale_tensor(
weight_key_map["down_proj_expert_in_scale_key"])
for idx, weight_tensor in enumerate([
ffn1_tensor, ffn2_tensor, moe_ffn1_weight_scale,
moe_ffn2_weight_scale, moe_ffn1_in_scale, moe_ffn2_in_scale
up_gate_proj_tensor, down_proj_tensor, up_gate_proj_weight_scale,
down_proj_weight_scale, up_gate_proj_in_scale, down_proj_in_scale
]):
name = added_wfp8afp8_attrs[idx]
setattr(
@@ -341,12 +341,12 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
False,
)
ffn1_out = paddle.empty(
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
config_ffn1 = {
config_up_gate_proj = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
@@ -354,15 +354,15 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
}
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
topk_ids, num_local_experts, config_ffn1["BLOCK_SIZE_M"])
topk_ids, num_local_experts, config_up_gate_proj["BLOCK_SIZE_M"])
max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (
ceil_div(max_possible_num_post_padded, config_ffn1["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config_ffn1["BLOCK_SIZE_N"]), )
ceil_div(max_possible_num_post_padded, config_up_gate_proj["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config_up_gate_proj["BLOCK_SIZE_N"]), )
permute_x = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
x,
scale=layer.moe_ffn1_in_scale,
scale=layer.up_gate_proj_in_scale,
topk_ids=topk_ids,
top_k=top_k,
intermediate_size=hidden_size,
@@ -370,10 +370,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
permute_x,
layer.moe_ffn1_weight,
ffn1_out,
layer.moe_ffn1_in_scale,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight,
up_gate_proj_out,
layer.up_gate_proj_in_scale,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
@@ -384,11 +384,11 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
K=hidden_size,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_bn=layer.moe_ffn1_weight.strides[2],
stride_cm=ffn1_out.strides[0],
stride_cn=ffn1_out.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=layer.up_gate_proj_weight.strides[2],
stride_cm=up_gate_proj_out.strides[0],
stride_cn=up_gate_proj_out.strides[1],
#
stride_asm=-1, # only used in blockwise fp8
stride_ask=-1, # only used in blockwise fp8
@@ -398,51 +398,51 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config_ffn1["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config_ffn1["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config_ffn1["BLOCK_SIZE_K"],
GROUP_SIZE_M=config_ffn1["GROUP_SIZE_M"],
BLOCK_SIZE_M=config_up_gate_proj["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config_up_gate_proj["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config_up_gate_proj["BLOCK_SIZE_K"],
GROUP_SIZE_M=config_up_gate_proj["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=False,
top_k=1,
compute_type_enum=1,
use_fp8_w8a8=True,
use_int8_w8a16=False,
even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
even_Ks=hidden_size % config_up_gate_proj["BLOCK_SIZE_K"] == 0,
)
ffn2_input = paddle.incubate.nn.functional.swiglu(
ffn1_out)
down_proj_input = paddle.incubate.nn.functional.swiglu(
up_gate_proj_out)
ffn2_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
ffn2_input,
scale=layer.moe_ffn2_in_scale,
down_proj_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
down_proj_input,
scale=layer.down_proj_in_scale,
topk_ids=topk_ids,
top_k=top_k,
intermediate_size=moe_intermediate_size,
tiled=True)
config_ffn2 = {
config_down_proj = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
ffn2_out = paddle.empty(
down_proj_out = paddle.empty(
(token_num * top_k, hidden_size),
dtype=x.dtype,
)
grid = (
ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )
ceil_div(max_possible_num_post_padded, config_down_proj["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config_down_proj["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
ffn2_input,
layer.moe_ffn2_weight,
ffn2_out,
layer.moe_ffn2_in_scale,
layer.moe_ffn2_weight_scale,
down_proj_input,
layer.down_proj_weight,
down_proj_out,
layer.down_proj_in_scale,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -451,13 +451,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=ffn2_input.strides[0],
stride_ak=ffn2_input.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_bn=layer.moe_ffn2_weight.strides[2],
stride_cm=ffn2_out.strides[0],
stride_cn=ffn2_out.strides[1],
stride_am=down_proj_input.strides[0],
stride_ak=down_proj_input.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=layer.down_proj_weight.strides[2],
stride_cm=down_proj_out.strides[0],
stride_cn=down_proj_out.strides[1],
stride_asm=-1,
stride_ask=-1,
stride_bse=-1,
@@ -466,20 +466,20 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config_ffn2["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config_ffn2["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config_ffn2["BLOCK_SIZE_K"],
GROUP_SIZE_M=config_ffn2["GROUP_SIZE_M"],
BLOCK_SIZE_M=config_down_proj["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config_down_proj["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config_down_proj["BLOCK_SIZE_K"],
GROUP_SIZE_M=config_down_proj["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=True,
top_k=1,
compute_type_enum=1,
use_fp8_w8a8=True,
use_int8_w8a16=False,
even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0,
even_Ks=moe_intermediate_size % config_down_proj["BLOCK_SIZE_K"] == 0,
)
ffn2_out.reshape_([token_num, top_k, hidden_size])
out = ffn2_out.sum(axis=1)
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
@@ -496,9 +496,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
Triton Group Gemm to compute Fused MoE.
"""
self.quant_config = quant_config
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
@@ -510,11 +510,11 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -537,14 +537,14 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
[0, 2, 1]).contiguous()
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
"""
check layer is valid for this method
"""
assert ffn1_weights[0].shape == [
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
@@ -563,8 +563,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
num_local_experts = layer.num_local_experts
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
E, N1, _ = layer.moe_ffn1_weight.shape
N2 = layer.moe_ffn2_weight.shape[1]
E, N1, _ = layer.up_gate_proj_weight.shape
N2 = layer.down_proj_weight.shape[1]
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
@@ -605,10 +605,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
x_q,
layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
layer.up_gate_proj_weight.view(paddle.float8_e4m3fn),
intermediate_cache1,
x_scale,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
@@ -619,17 +619,17 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
K=hidden_size,
stride_am=x_q.strides[0],
stride_ak=x_q.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[2],
stride_bn=layer.moe_ffn1_weight.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[2],
stride_bn=layer.up_gate_proj_weight.strides[1],
stride_cm=intermediate_cache1.strides[0],
stride_cn=intermediate_cache1.strides[1],
#
stride_asm=x_scale.strides[0], # only used in blockwise fp8
stride_ask=x_scale.strides[1], # only used in blockwise fp8
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bsk=layer.moe_ffn1_weight_scale.strides[2],
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
group_n=self.quant_config.weight_block_size[1],
group_k=self.quant_config.weight_block_size[0],
# Meta-parameters
@@ -656,10 +656,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
x_q,
layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
layer.down_proj_weight.view(paddle.float8_e4m3fn),
intermediate_cache3,
x_scale,
layer.moe_ffn2_weight_scale,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -670,16 +670,16 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
K=moe_intermediate_size,
stride_am=x_q.strides[0],
stride_ak=x_q.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[2],
stride_bn=layer.moe_ffn2_weight.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[2],
stride_bn=layer.down_proj_weight.strides[1],
stride_cm=intermediate_cache3.strides[0],
stride_cn=intermediate_cache3.strides[1],
stride_asm=x_scale.strides[0], # only used in blockwise fp8
stride_ask=x_scale.strides[1], # only used in blockwise fp8
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bsk=layer.moe_ffn2_weight_scale.strides[2],
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=layer.down_proj_weight_scale.strides[2],
stride_bsn=layer.down_proj_weight_scale.strides[1],
group_n=self.quant_config.weight_block_size[1],
group_k=self.quant_config.weight_block_size[0],
# Meta-parameters

View File

@@ -41,16 +41,16 @@ class Wint2MoeMethod(QuantMethodBase):
"""
pass
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
"""
check layer is valid for this method
"""
assert len(
ffn1_weights
) == layer.num_local_experts, "ffn1_weights length should be equal to num_local_experts."
up_gate_proj_weights
) == layer.num_local_experts, "up_gate_proj_weights length should be equal to num_local_experts."
assert len(
ffn2_weights
) == layer.num_local_experts, "ffn2_weights length should be equal to num_local_experts."
down_proj_weights
) == layer.num_local_experts, "down_proj_weights length should be equal to num_local_experts."
def create_weights(self, layer: nn.Layer, state_dict):
"""
@@ -78,96 +78,96 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
"""
Paddle cutlass process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
ffn1_expert_super_scales_key = layer.weight_key_map.get(
"ffn1_expert_super_scales_key", None)
ffn2_expert_super_scales_key = layer.weight_key_map.get(
"ffn2_expert_super_scales_key", None)
ffn1_expert_code_scale_key = layer.weight_key_map.get(
"ffn1_expert_code_scale_key", None)
ffn2_expert_code_scale_key = layer.weight_key_map.get(
"ffn2_expert_code_scale_key", None)
ffn1_expert_code_zp_key = layer.weight_key_map.get(
"ffn1_expert_code_zp_key", None)
ffn2_expert_code_zp_key = layer.weight_key_map.get(
"ffn2_expert_code_zp_key", None)
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
"down_proj_expert_weight_scale_key", None)
up_gate_proj_expert_super_scales_key = layer.weight_key_map.get(
"up_gate_proj_expert_super_scales_key", None)
down_proj_expert_super_scales_key = layer.weight_key_map.get(
"down_proj_expert_super_scales_key", None)
up_gate_proj_expert_code_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_code_scale_key", None)
down_proj_expert_code_scale_key = layer.weight_key_map.get(
"down_proj_expert_code_scale_key", None)
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get(
"up_gate_proj_expert_code_zp_key", None)
down_proj_expert_code_zp_key = layer.weight_key_map.get(
"down_proj_expert_code_zp_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
ffn1_super_scales = []
ffn2_super_scales = []
ffn1_code_scale = []
ffn2_code_scale = []
ffn1_code_zp = []
ffn2_code_zp = []
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
up_gate_proj_super_scales = []
down_proj_super_scales = []
up_gate_proj_code_scale = []
down_proj_code_scale = []
up_gate_proj_code_zp = []
down_proj_code_zp = []
for i in range(layer.num_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
up_gate_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
down_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
ffn1_super_scales.append(
down_proj_expert_weight_scale_key.format(expert_idx))))
up_gate_proj_super_scales.append(
get_tensor(
state_dict.pop(
ffn1_expert_super_scales_key.format(expert_idx))))
ffn2_super_scales.append(
up_gate_proj_expert_super_scales_key.format(expert_idx))))
down_proj_super_scales.append(
get_tensor(
state_dict.pop(
ffn2_expert_super_scales_key.format(expert_idx))))
ffn1_code_scale.append(
down_proj_expert_super_scales_key.format(expert_idx))))
up_gate_proj_code_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_code_scale_key.format(expert_idx))))
ffn2_code_scale.append(
up_gate_proj_expert_code_scale_key.format(expert_idx))))
down_proj_code_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_code_scale_key.format(expert_idx))))
ffn1_code_zp.append(
down_proj_expert_code_scale_key.format(expert_idx))))
up_gate_proj_code_zp.append(
get_tensor(
state_dict.pop(
ffn1_expert_code_zp_key.format(expert_idx))))
ffn2_code_zp.append(
up_gate_proj_expert_code_zp_key.format(expert_idx))))
down_proj_code_zp.append(
get_tensor(
state_dict.pop(
ffn2_expert_code_zp_key.format(expert_idx))))
down_proj_expert_code_zp_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
ffn1_super_scales = paddle.stack(ffn1_super_scales, axis=0)
ffn2_super_scales = paddle.stack(ffn2_super_scales, axis=0)
ffn1_code_scale = paddle.stack(ffn1_code_scale, axis=0)
ffn2_code_scale = paddle.stack(ffn2_code_scale, axis=0)
ffn1_code_zp = paddle.stack(ffn1_code_zp, axis=0)
ffn2_code_zp = paddle.stack(ffn2_code_zp, axis=0)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
up_gate_proj_super_scales = paddle.stack(up_gate_proj_super_scales, axis=0)
down_proj_super_scales = paddle.stack(down_proj_super_scales, axis=0)
up_gate_proj_code_scale = paddle.stack(up_gate_proj_code_scale, axis=0)
down_proj_code_scale = paddle.stack(down_proj_code_scale, axis=0)
up_gate_proj_code_zp = paddle.stack(up_gate_proj_code_zp, axis=0)
down_proj_code_zp = paddle.stack(down_proj_code_zp, axis=0)
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale,
"moe_ffn1_super_scales": ffn1_super_scales,
"moe_ffn2_super_scales": ffn2_super_scales,
"moe_ffn1_code_scale": ffn1_code_scale,
"moe_ffn2_code_scale": ffn2_code_scale,
"moe_ffn1_code_zp": ffn1_code_zp,
"moe_ffn2_code_zp": ffn2_code_zp
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
"up_gate_proj_super_scales": up_gate_proj_super_scales,
"down_proj_super_scales": down_proj_super_scales,
"up_gate_proj_code_scale": up_gate_proj_code_scale,
"down_proj_code_scale": down_proj_code_scale,
"up_gate_proj_code_zp": up_gate_proj_code_zp,
"down_proj_code_zp": down_proj_code_zp
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@@ -200,7 +200,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
x,
gate_out,
layer.gate_correction_bias,
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
@@ -210,17 +210,17 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2(
permute_input,
token_nums_per_expert,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None,
layer.moe_ffn1_super_scales,
layer.moe_ffn2_super_scales,
layer.moe_ffn1_weight_scale,
layer.moe_ffn1_code_scale,
layer.moe_ffn1_code_zp,
layer.moe_ffn2_weight_scale,
layer.moe_ffn2_code_scale,
layer.moe_ffn2_code_zp,
layer.up_gate_proj_super_scales,
layer.down_proj_super_scales,
layer.up_gate_proj_weight_scale,
layer.up_gate_proj_code_scale,
layer.up_gate_proj_code_zp,
layer.down_proj_weight_scale,
layer.down_proj_code_scale,
layer.down_proj_code_zp,
False,
)
@@ -271,7 +271,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
)
num_tokens, K = x.shape
E, _, N = layer.moe_ffn1_weight.shape
E, _, N = layer.up_gate_proj_weight.shape
M = num_tokens
top_k = topk_ids.shape[1]
@@ -308,12 +308,12 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
moe_wint2_ffn_kernel[grid](
x,
layer.moe_ffn1_weight,
layer.up_gate_proj_weight,
intermediate_cache1,
layer.moe_ffn1_weight_scale,
layer.moe_ffn1_super_scales,
layer.moe_ffn1_code_scale,
layer.moe_ffn1_code_zp,
layer.up_gate_proj_weight_scale,
layer.up_gate_proj_super_scales,
layer.up_gate_proj_code_scale,
layer.up_gate_proj_code_zp,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -321,7 +321,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
num_valid_tokens,
max_possible_num_post_padded,
# Matrix dimensions
N=layer.moe_ffn1_weight.shape[-1],
N=layer.up_gate_proj_weight.shape[-1],
K=x.shape[-1],
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
@@ -329,15 +329,15 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
# (A has M rows).
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=1,
stride_cm=intermediate_cache1.strides[-2],
stride_cn=1,
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bsk=layer.moe_ffn1_weight_scale.strides[1],
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=layer.up_gate_proj_weight_scale.strides[1],
stride_bsn=1,
stride_bce=layer.moe_ffn1_code_scale.strides[0],
stride_bce=layer.up_gate_proj_code_scale.strides[0],
stride_bck=1,
stride_bcn=1,
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
@@ -361,17 +361,17 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
}
grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(layer.moe_ffn2_weight.shape[-1], config["BLOCK_SIZE_N"]), )
ceil_div(layer.down_proj_weight.shape[-1], config["BLOCK_SIZE_N"]), )
moe_wint2_ffn_kernel[grid](
intermediate_cache2,
layer.moe_ffn2_weight,
layer.down_proj_weight,
intermediate_cache3,
layer.moe_ffn2_weight_scale,
layer.moe_ffn2_super_scales,
layer.moe_ffn2_code_scale,
layer.moe_ffn2_code_zp,
layer.down_proj_weight_scale,
layer.down_proj_super_scales,
layer.down_proj_code_scale,
layer.down_proj_code_zp,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -379,7 +379,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
num_valid_tokens,
max_possible_num_post_padded,
# Matrix dimensions
N=layer.moe_ffn2_weight.shape[-1],
N=layer.down_proj_weight.shape[-1],
K=intermediate_cache2.shape[-1],
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
@@ -387,15 +387,15 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
# (A has M rows).
stride_am=intermediate_cache2.strides[0],
stride_ak=1,
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=1,
stride_cm=intermediate_cache3.strides[-2],
stride_cn=1,
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bsk=layer.moe_ffn2_weight_scale.strides[1],
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=layer.down_proj_weight_scale.strides[1],
stride_bsn=1,
stride_bce=layer.moe_ffn2_code_scale.strides[0],
stride_bce=layer.down_proj_code_scale.strides[0],
stride_bck=1,
stride_bcn=1,
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],

View File

@@ -38,14 +38,14 @@ class XPUMoEMethod(MoEMethodBase):
Paddle cutlass create weight process.
"""
# bf16
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
for weights in [ffn1_weights, ffn2_weights]:
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
for weights in [up_gate_proj_weights, down_proj_weights]:
for idx, weight in enumerate(weights):
weights[idx] = weight.transpose([1, 0])
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
for idx, weight_tensor in enumerate(
[stacked_ffn1_weights, stacked_ffn2_weights]):
[stacked_up_gate_proj_weights, stacked_down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
setattr(
layer, weight_name,
@@ -71,13 +71,13 @@ class XPUMoEMethod(MoEMethodBase):
x,
layer.gate_weight.transpose([1, 0]),
layer.gate_correction_bias,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
None, # ffn1 bias
None, # ffn2 bias
None, # ffn1 scale
None, # ffn2 scale
None, # ffn1_in_scale
layer.up_gate_proj_weight,
layer.down_proj_weight,
None, # up_gate_proj bias
None, # down_proj bias
None, # up_gate_proj scale
None, # down_proj scale
None, # up_gate_proj_in_scale
"", # moe_quant_type
layer.top_k,
False, # moe group, used in deepseek
@@ -129,20 +129,20 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
assert ffn1_weights[0].shape == [
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
added_scale_attrs = ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]
added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
added_scale_attrs = ["up_gate_proj_weight_scale", "down_proj_weight_scale"]
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = added_weight_attrs[idx]
scale_name = added_scale_attrs[idx]
@@ -189,16 +189,16 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
x,
layer.gate_weight.transpose([1, 0]),
layer.gate_correction_bias,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
None, # ffn1 bias
None, # ffn2 bias
(layer.moe_ffn1_weight_scale
if hasattr(layer, "moe_ffn1_weight_scale") else None),
(layer.moe_ffn2_weight_scale
if hasattr(layer, "moe_ffn2_weight_scale") else None),
(layer.moe_ffn2_in_scale
if hasattr(layer, "moe_ffn2_in_scale") else None),
layer.up_gate_proj_weight,
layer.down_proj_weight,
None, # up_gate_proj bias
None, # down_proj bias
(layer.up_gate_proj_weight_scale
if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale
if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale
if hasattr(layer, "down_proj_in_scale") else None),
self.moe_quant_type,
layer.top_k,
False, # moe group, used in deepseek

View File

@@ -145,13 +145,13 @@ class FusedMoE(nn.Layer):
shape=gate_correction_bias_shape,
dtype="float32",
)
ffn1_output_dim = self.moe_intermediate_size * 2
up_gate_proj_output_dim = self.moe_intermediate_size * 2
if self.moe_quant_type in ["fp8", "wint8"]:
ffn1_weight_shape = [self.num_local_experts, ffn1_output_dim, self.hidden_size]
ffn2_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
up_gate_proj_weight_shape = [self.num_local_experts, up_gate_proj_output_dim, self.hidden_size]
down_proj_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
else:
ffn1_weight_shape = [self.num_local_experts, self.hidden_size, ffn1_output_dim]
ffn2_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
up_gate_proj_weight_shape = [self.num_local_experts, self.hidden_size, up_gate_proj_output_dim]
down_proj_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
# Create parameters
if self.moe_quant_type == "fp8":
@@ -161,15 +161,15 @@ class FusedMoE(nn.Layer):
self.weight_dtype = "int8"
self.init_weight_only_scale()
# FFN1 parameters
self.moe_ffn1_weight = self.create_parameter(
shape=ffn1_weight_shape,
# up_gate_proj parameters
self.up_gate_proj_weight = self.create_parameter(
shape=up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
# FFN2 parameters
self.moe_ffn2_weight = self.create_parameter(
shape=ffn2_weight_shape,
# down_proj parameters
self.down_proj_weight = self.create_parameter(
shape=down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
@@ -178,44 +178,44 @@ class FusedMoE(nn.Layer):
"""
Initialize the weight scale.
"""
self.moe_ffn1_weight_scale = self.create_parameter(
self.up_gate_proj_weight_scale = self.create_parameter(
shape=[self.num_local_experts, self.moe_intermediate_size * 2],
dtype=self._dtype,
)
self.moe_ffn2_weight_scale = self.create_parameter(
self.down_proj_weight_scale = self.create_parameter(
shape=[self.num_local_experts, self.hidden_size],
dtype=self._dtype,
)
def load_experts_weight(self, state_dict: dict,
ffn1_expert_weight_key: str,
ffn2_expert_weight_key: str):
up_gate_proj_expert_weight_key: str,
down_proj_expert_weight_key: str):
"""
Load experts weight from state_dict.
Args:
state_dict (dict): The state_dict of model.
ffn1_expert_weight_key (str): The key of ffn1 expert weight.
ffn2_expert_weight_key (str): The key of ffn2 expert weight.
up_gate_proj_expert_weight_key (str): The key of up_gate_proj expert weight.
down_proj_expert_weight_key (str): The key of down_proj expert weight.
"""
ffn1_weights = []
ffn2_weights = []
is_ffn_merged = ffn1_expert_weight_key.format(
up_gate_proj_weights = []
down_proj_weights = []
is_ffn_merged = up_gate_proj_expert_weight_key.format(
self.expert_id_offset) in state_dict
if is_ffn_merged:
for i in range(self.num_local_experts):
expert_idx = self.expert_id_offset + i
ffn1_weights.append(
up_gate_proj_weights.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_key.format(expert_idx))))
ffn2_weights.append(
up_gate_proj_expert_weight_key.format(expert_idx))))
down_proj_weights.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_key.format(expert_idx))))
down_proj_expert_weight_key.format(expert_idx))))
else:
gate_expert_weight_key = ffn1_expert_weight_key.replace(
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace(
"up_gate_proj", "gate_proj")
up_expert_weight_key = ffn1_expert_weight_key.replace(
up_expert_weight_key = up_gate_proj_expert_weight_key.replace(
"up_gate_proj", "up_proj")
for j in range(self.num_local_experts):
expert_idx = self.expert_id_offset + j
@@ -223,12 +223,12 @@ class FusedMoE(nn.Layer):
state_dict.pop(gate_expert_weight_key.format(expert_idx)))
up = get_tensor(
state_dict.pop(up_expert_weight_key.format(expert_idx)))
ffn1_weights.append(paddle.concat([gate, up], axis=-1))
ffn2_weights.append(
up_gate_proj_weights.append(paddle.concat([gate, up], axis=-1))
down_proj_weights.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_key.format(expert_idx))))
return ffn1_weights, ffn2_weights
down_proj_expert_weight_key.format(expert_idx))))
return up_gate_proj_weights, down_proj_weights
def extract_moe_ffn_weights(self, state_dict: dict):
"""
@@ -239,30 +239,30 @@ class FusedMoE(nn.Layer):
Returns:
tuple: A tuple containing two lists:
- ffn1_weights: List of tensors for first FFN layer weights
- ffn2_weights: List of tensors for second FFN layer weights
- up_gate_proj_weights: List of tensors for first FFN layer weights
- down_proj_weights: List of tensors for second FFN layer weights
Raises:
AssertionError: If required weight keys are missing or number of weights
doesn't match number of local experts.
"""
ffn1_expert_weight_key = self.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = self.weight_key_map.get(
"ffn2_expert_weight_key", None)
assert ffn1_expert_weight_key is not None, "ffn1_expert_weight_key should not be none."
assert ffn2_expert_weight_key is not None, "ffn2_expert_weight_key should not be none."
up_gate_proj_expert_weight_key = self.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = self.weight_key_map.get(
"down_proj_expert_weight_key", None)
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
ffn1_weights, ffn2_weights = self.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
up_gate_proj_weights, down_proj_weights = self.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
assert len(
ffn1_weights
) == self.num_local_experts, "ffn1_weights length should be equal to num_local_experts."
up_gate_proj_weights
) == self.num_local_experts, "up_gate_proj_weights length should be equal to num_local_experts."
assert len(
ffn2_weights
) == self.num_local_experts, "ffn2_weights length should be equal to num_local_experts."
down_proj_weights
) == self.num_local_experts, "down_proj_weights length should be equal to num_local_experts."
return ffn1_weights, ffn2_weights
return up_gate_proj_weights, down_proj_weights
def extract_gate_correction_bias(self, gate_correction_bias_key,
state_dict):

View File

@@ -46,11 +46,11 @@ class ParallelEHProjection(nn.Layer):
prefix (str): full name of the layer in the state dict
"""
super(ParallelEHProjection, self).__init__()
self.linear_weight_key = prefix + ".weight"
self.weight_key = prefix + ".weight"
if with_bias:
self.linear_bias_key = prefix + ".bias"
self.bias_key = prefix + ".bias"
else:
self.linear_bias_key = None
self.bias_key = None
self.use_ep = fd_config.parallel_config.use_ep
self.column_cut = True
@@ -66,26 +66,26 @@ class ParallelEHProjection(nn.Layer):
else:
if self.column_cut:
need_gather = True
self.out_linear = ColumnParallelLinear(
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False, # False diff更小
)
else:
self.out_linear = RowParallelLinear(
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False, # False diff更小
)
@@ -100,20 +100,20 @@ class ParallelEHProjection(nn.Layer):
if self.use_ep:
self.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
get_tensor(state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype()))
else:
weight_tensor = get_tensor(
state_dict.pop(self.linear_weight_key)).astype(
state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype())
if self.out_linear.weight.shape != weight_tensor.shape:
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.out_linear.weight.set_value(weight_tensor)
self.linear.weight.set_value(weight_tensor)
if self.linear_bias_key is not None:
bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(
paddle.get_default_dtype())
self.out_linear.bias.set_value(bias)
self.linear.bias.set_value(bias)
def forward(self, input):
"""
@@ -129,5 +129,5 @@ class ParallelEHProjection(nn.Layer):
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
logits = self.out_linear(logits)
logits = self.linear(logits)
return logits

View File

@@ -43,7 +43,7 @@ class RMSNorm(nn.Layer):
hidden_size: int,
eps: float = 1e-5,
prefix: str = "",
linear_bias: paddle.Tensor = None,
bias: paddle.Tensor = None,
quant_scale: float = None,
begin_norm_axis: int = 1,
) -> None:
@@ -57,7 +57,7 @@ class RMSNorm(nn.Layer):
hidden_size (int) : size of hidden state.
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
prefix(str,optional):The name of current layer. Defaults to "".
linear_bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
begin_norm_axis (int, optional): The axis along which to perform normalization. Defaults to 1.
@@ -78,7 +78,7 @@ class RMSNorm(nn.Layer):
self.norm_func: Callable = fused_add_rms_norm
else:
self.norm_func: Callable = fused_rms_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self.bias: Optional[paddle.Tensor] = bias
self.quant_scale: Optional[float] = quant_scale
self._dtype: str = self._helper.get_default_dtype()
self._norm_weight_dtype: str = self._dtype
@@ -94,9 +94,9 @@ class RMSNorm(nn.Layer):
Initialize the weights and biases.
"""
self.ln_weight = None
self.weight = None
if self.with_weight:
self.ln_weight = self.create_parameter(
self.weight = self.create_parameter(
shape=[self.hidden_size],
default_initializer=nn.initializer.Constant(value=1.0),
dtype=self._norm_weight_dtype,
@@ -115,7 +115,7 @@ class RMSNorm(nn.Layer):
weight_tensor = paddle.cast(
get_tensor(state_dict.pop(self.weight_key)),
self._norm_weight_dtype)
self.ln_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)
def forward(
self,
@@ -139,18 +139,18 @@ class RMSNorm(nn.Layer):
"""
if current_platform.is_gcu():
if residual_input is None:
return rms_norm(x, self.ln_weight, self.eps)
return rms_norm(x, self.weight, self.eps)
norm_out = self.norm_func(
x, residual_input, self.ln_weight, self.eps
x, residual_input, self.weight, self.eps
)
else:
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_weight=self.weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.linear_bias,
bias=self.bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
@@ -174,7 +174,7 @@ class LayerNorm(nn.Layer):
hidden_size: int,
eps: float = 1e-5,
prefix="",
linear_bias: paddle.Tensor = None,
bias: paddle.Tensor = None,
quant_scale: float = None,
with_bias: bool = False,
):
@@ -189,7 +189,7 @@ class LayerNorm(nn.Layer):
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
with_bias (bool):Whether to include bias or not. Defaults to False.
Raises:
@@ -212,7 +212,7 @@ class LayerNorm(nn.Layer):
self.norm_func: Callable = paddle.nn.functional.layer_norm
else:
self.norm_func: Callable = fused_layer_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self.bias: Optional[paddle.Tensor] = bias
self._dtype: str = self._helper.get_default_dtype()
self._norm_weight_dtype: str = "float32"
@@ -227,16 +227,16 @@ class LayerNorm(nn.Layer):
Initialize the weights and biases.
"""
self.ln_weight = None
self.weight = None
if self.with_weight:
self.ln_weight = self.create_parameter(
self.weight = self.create_parameter(
shape=[self.hidden_size],
default_initializer=nn.initializer.Constant(value=1.0),
dtype=self._norm_weight_dtype,
)
self.ln_bias = None
self.bias = None
if self.with_bias:
self.ln_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.hidden_size],
is_bias=True,
dtype=self._norm_weight_dtype,
@@ -255,14 +255,14 @@ class LayerNorm(nn.Layer):
weight_tensor = paddle.cast(
get_tensor(state_dict.pop(self.weight_key)),
self._norm_weight_dtype)
self.ln_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)
# bias
if self.with_bias:
bias_tensor = paddle.cast(
get_tensor(state_dict.pop(self.bias_key)),
self._norm_weight_dtype)
self.ln_bias.set_value(bias_tensor)
self.bias.set_value(bias_tensor)
def forward(
self,
@@ -285,10 +285,10 @@ class LayerNorm(nn.Layer):
operations (like linear transformation) on the `residual_input`.
"""
if current_platform.is_iluvatar():
if self.ln_weight is None and self.ln_bias is None:
if self.weight is None and self.bias is None:
out = x
if self.linear_bias is not None:
out += self.linear_bias
if self.bias is not None:
out += self.bias
if residual_input is not None:
out += residual_input
return out, out
@@ -303,8 +303,8 @@ class LayerNorm(nn.Layer):
out = self.norm_func(
x=y,
normalized_shape=y.shape[1:],
weight=self.ln_weight,
bias=self.linear_bias,
weight=self.weight,
bias=self.bias,
epsilon=self.eps,
)
return out, y
@@ -312,19 +312,19 @@ class LayerNorm(nn.Layer):
out = self.norm_func(
x=x,
normalized_shape=x.shape[1:],
weight=self.ln_weight,
bias=self.linear_bias,
weight=self.weight,
bias=self.bias,
epsilon=self.eps,
)
return out
else:
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=self.ln_bias,
norm_weight=self.weight,
norm_bias=self.bias,
epsilon=self.eps,
begin_norm_axis=1,
bias=self.linear_bias,
bias=self.bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,

View File

@@ -78,8 +78,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer):
layer.linear_weight_shape.reverse()
layer.linear_weight_scale = layer.create_parameter(
layer.weight_shape.reverse()
layer.weight_scale = layer.create_parameter(
shape=[
(layer.output_size + self.quant_config.weight_block_size[0] -
1) // self.quant_config.weight_block_size[0],
@@ -95,8 +95,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
weight_tensor = weights.transpose([1, 0])
quanted_weight_tensor, weight_block_scale_tensor = (
per_block_cast_to_fp8(weight_tensor))
layer.linear_weight.copy_(quanted_weight_tensor, False)
layer.linear_weight_scale.set_value(weight_block_scale_tensor)
layer.weight.copy_(quanted_weight_tensor, False)
layer.weight_scale.set_value(weight_block_scale_tensor)
def process_prequanted_weights(self, layer, state_dict):
"""
@@ -106,10 +106,10 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
quant_weight = quant_weight.transpose([1, 0]).contiguous()
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
weight_scale = weight_scale.transpose([1, 0])
layer.linear_weight_scale.set_value(weight_scale)
layer.weight_scale.set_value(weight_scale)
def apply(self, layer, x):
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
@@ -119,9 +119,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
deep_gemm.gemm_fp8_fp8_bf16_nt(
(x, x_scale_tensor),
(layer.linear_weight, layer.linear_weight_scale),
(layer.weight, layer.weight_scale),
linear_out,
)
if layer.with_bias:
linear_out = paddle.add(linear_out, layer.linear_bias)
linear_out = paddle.add(linear_out, layer.bias)
return linear_out

View File

@@ -96,7 +96,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
act_scale = get_tensor(state_dict.pop(layer.act_scale_key))
quant_weight = quant_weight.transpose([1, 0]).contiguous()
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
self.act_scale = act_scale.item()
self.total_scale = (act_scale * weight_scale).item()
@@ -118,7 +118,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
linear_out = cutlass_fp8_fp8_half_gemm_fused(
fp8_x,
layer.linear_weight,
layer.weight,
transpose_x=False,
transpose_y=True,
bias=None,

View File

@@ -63,8 +63,8 @@ class W4AFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer):
layer.linear_weight_shape.reverse()
layer.linear_weight_shape[0] //= 2
layer.weight_shape.reverse()
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
pass
@@ -77,16 +77,16 @@ class W4AFP8LinearMethod(QuantMethodBase):
scale_dtype="float16",
))
weight_scale_tensor = paddle.view(weight_scale_tensor, layer._dtype)
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(weight_scale_tensor)
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor)
def apply(self, layer, x):
linear_out = fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16(
x,
layer.linear_weight,
layer.linear_weight_scale,
layer.weight,
layer.weight_scale,
zero_points=None,
bias=layer.linear_bias if layer.add_bias else None,
bias=layer.bias if layer.add_bias else None,
out_scale=self.quant_config.weight_scale_dict.get(layer.prefix +
".weight_scale")
/ (self.quant_config.act_scale_dict.get(layer.prefix +

View File

@@ -69,7 +69,7 @@ class W8A8LinearMethod(QuantMethodBase):
self.smooth_quant_method = SmoothQuantLinearMethod(quant_config)
def create_weights(self, layer):
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
layer.weight_dtype = "int8"
if self.quant_config.use_smooth_quant:
self.smooth_quant_method.create_weights(layer)
@@ -101,21 +101,21 @@ class W8A8LinearMethod(QuantMethodBase):
if self.skip_quant:
logger.debug(f"{layer.prefix} skip quant")
weight_tensor = weights.cast(layer._dtype)
layer.linear_weight.set_value(weight_tensor)
layer.weight.set_value(weight_tensor)
else:
weight_tensor = weights.transpose([1, 0])
weight_tensor = paddle.cast(weight_tensor, "int8")
layer.linear_weight.set_value(weight_tensor)
layer.weight.set_value(weight_tensor)
def apply(self, layer, x):
if self.skip_quant:
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
linear_out = paddle.matmul(x, layer.weight, False, True)
return linear_out
if self.quant_config.use_gemm_dequant:
linear_out = fastdeploy.model_executor.ops.gpu.gemm_dequant(
x, layer.linear_weight, layer.linear_out_scale, layer._dtype)
x, layer.weight, layer.linear_out_scale, layer._dtype)
else:
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
linear_out = paddle.matmul(x, layer.weight, False, True)
linear_out = fastdeploy.model_executor.ops.gpu.dequant_int8(
linear_out, layer.linear_out_scale, layer._dtype)
return linear_out

View File

@@ -77,12 +77,12 @@ class WeightOnlyConfig(QuantConfigBase):
return GCUWeightOnlyLinearMethod(self)
elif current_platform.is_dcu():
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.backends import (
DCUTritonWeightOnlyMoEMethod)
from fastdeploy.model_executor.layers.backends import \
DCUTritonWeightOnlyMoEMethod
return DCUTritonWeightOnlyMoEMethod(self)
else:
from fastdeploy.model_executor.layers.backends import (
DCUWeightOnlyLinearMethod)
from fastdeploy.model_executor.layers.backends import \
DCUWeightOnlyLinearMethod
return DCUWeightOnlyLinearMethod(self)
else:
if isinstance(layer, FusedMoE):
@@ -152,14 +152,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
def create_weights(self, layer):
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
weight_scale_shape = [layer.weight_shape[1]]
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype=layer._dtype,
is_bias=False,
)
@@ -171,9 +171,9 @@ class WeightOnlyLinearMethod(QuantMethodBase):
def apply(self, layer, x):
linear_out = weight_only_linear(
x,
weight=layer.linear_weight,
bias=layer.linear_bias if layer.add_bias else None,
weight_scale=layer.linear_weight_scale,
weight=layer.weight,
bias=layer.bias if layer.add_bias else None,
weight_scale=layer.weight_scale,
weight_dtype="int8"
if self.quant_config.name() == "wint8" else "int4",
arch=self.quant_config.weight_only_linear_arch,
@@ -204,8 +204,8 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
layer.linear_weight.set_value(quant_weight)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quant_weight)
layer.weight_scale.set_value(
weight_scale.astype(paddle.get_default_dtype()))
def process_loaded_weights(self, layer, weight) -> None:
@@ -216,6 +216,6 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
arch=self.quant_config.weight_only_linear_arch,
)
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(
weight_scale_tensor.astype(paddle.get_default_dtype()))

View File

@@ -70,11 +70,11 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
def create_weights(self, layer):
"""
"""
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
layer.weight_dtype = "float8_e4m3fn"
# TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
self.skip_quant = False
layer.linear_weight_scale = layer.create_parameter(
layer.weight_scale = layer.create_parameter(
shape=[1],
dtype="float32",
is_bias=False,
@@ -86,7 +86,7 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
"""
if self.skip_quant:
weight_tensor = weights.cast(layer._dtype)
layer.linear_weight.set_value(weight_tensor)
layer.weight.set_value(weight_tensor)
return
if weights.dtype != paddle.float8_e4m3fn:
self.use_per_token_if_dynamic = True
@@ -95,22 +95,22 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
weight_tensor,
use_per_token_if_dynamic=False,
)
layer.linear_weight.copy_(qweight, False)
layer.linear_weight_scale.set_value(weight_scale)
layer.weight.copy_(qweight, False)
layer.weight_scale.set_value(weight_scale)
def apply(self, layer, x):
"""
"""
if self.skip_quant:
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
linear_out = paddle.matmul(x, layer.weight, False, True)
return linear_out
if self.use_per_token_if_dynamic:
out_type = x.dtype
a_q, a_scales = scaled_fp8_quant(
x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
linear_out = cutlass_scaled_mm(a_q, layer.linear_weight, a_scales,
layer.linear_weight_scale, out_type,
layer.linear_bias)
linear_out = cutlass_scaled_mm(a_q, layer.weight, a_scales,
layer.weight_scale, out_type,
layer.bias)
else:
raise NotImplementedError
return linear_out

View File

@@ -48,22 +48,22 @@ def load_ep_checkpoint(model_path: str,
config.num_experts_start_offset,
config.num_experts_start_offset + config.num_experts_per_rank,
):
ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
ffn2_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
down_proj_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
ffn2_quant_key = (
up_gate_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
down_proj_quant_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight")
ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
ffn2_scale_key = (
up_gate_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
down_proj_scale_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale")
num_local_ffn_keys.append(ffn1_key)
num_local_ffn_keys.append(ffn2_key)
num_local_ffn_keys.append(ffn1_quant_key)
num_local_ffn_keys.append(ffn2_quant_key)
num_local_ffn_keys.append(ffn1_scale_key)
num_local_ffn_keys.append(ffn2_scale_key)
num_local_ffn_keys.append(up_gate_proj_key)
num_local_ffn_keys.append(down_proj_key)
num_local_ffn_keys.append(up_gate_proj_quant_key)
num_local_ffn_keys.append(down_proj_quant_key)
num_local_ffn_keys.append(up_gate_proj_scale_key)
num_local_ffn_keys.append(down_proj_scale_key)
for k in num_local_ffn_keys:
if k in weight_list:

View File

@@ -61,7 +61,7 @@ class DeepSeekV3MLP(nn.Layer):
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
self.up_gate_proj = MergedColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.up_gate_proj",
input_size=fd_config.model_config.hidden_size,
@@ -88,13 +88,13 @@ class DeepSeekV3MLP(nn.Layer):
def load_state_dict(self, state_dict):
"""
"""
self.gate_up_proj.load_state_dict(state_dict)
self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict)
def forward(self, x):
"""
"""
gate_up_out = self.gate_up_proj(x)
gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out)
return down_out
@@ -115,9 +115,9 @@ class DeepSeekV3MoE(nn.Layer):
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key":
f"{prefix}.gate.e_score_correction_bias",
"ffn1_expert_weight_key":
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight",
"ffn2_expert_weight_key":
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.weight",
}
@@ -528,7 +528,7 @@ class DeepSeekV3Model(nn.Layer):
self.num_layers = fd_config.model_config.num_hidden_layers
fd_config.model_config.pretrained_config.prefix_name = "deepseek_v3"
self.embeddings = VocabParallelEmbedding(
self.embed_tokens = VocabParallelEmbedding(
fd_config,
num_embeddings=fd_config.model_config.vocab_size,
embedding_dim=fd_config.model_config.hidden_size,
@@ -554,7 +554,7 @@ class DeepSeekV3Model(nn.Layer):
"""
Load model parameters from a given state dictionary.
"""
self.embeddings.load_state_dict(state_dict)
self.embed_tokens.load_state_dict(state_dict)
self.norm.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
@@ -569,7 +569,7 @@ class DeepSeekV3Model(nn.Layer):
):
"""
"""
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
residual = None
for i in range(self.num_layers):

View File

@@ -23,6 +23,7 @@ import numpy as np
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
@@ -55,7 +56,7 @@ class Ernie4_5_MLP(nn.Layer):
) -> None:
super().__init__()
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.gate_up_proj = MergedColumnParallelLinear(
self.up_gate_proj = MergedColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.up_gate_proj",
input_size=fd_config.model_config.hidden_size,
@@ -79,11 +80,11 @@ class Ernie4_5_MLP(nn.Layer):
)
def load_state_dict(self, state_dict):
self.gate_up_proj.load_state_dict(state_dict)
self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor):
gate_up_out = self.gate_up_proj(hidden_states)
gate_up_out = self.up_gate_proj(hidden_states)
act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out)
return down_out
@@ -104,17 +105,17 @@ class Ernie4_5_MoE(nn.Layer):
f"{prefix}.gate.weight",
"gate_correction_bias_key":
f"{prefix}.moe_statics.e_score_correction_bias",
"ffn1_expert_weight_key":
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"ffn2_expert_weight_key":
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.quant_weight",
"ffn1_expert_weight_scale_key":
"up_gate_proj_expert_weight_scale_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"ffn2_expert_weight_scale_key":
"down_proj_expert_weight_scale_key":
f"{prefix}.experts.{{}}.down_proj.weight_scale",
"ffn1_expert_in_scale_key":
"up_gate_proj_expert_in_scale_key":
f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
"ffn2_expert_in_scale_key":
"down_proj_expert_in_scale_key":
f"{prefix}.experts.{{}}.down_proj.activation_scale",
}
elif moe_quant_type == "w4w2":
@@ -123,25 +124,25 @@ class Ernie4_5_MoE(nn.Layer):
f"{prefix}.gate.weight",
"gate_correction_bias_key":
f"{prefix}.moe_statics.e_score_correction_bias",
"ffn1_expert_weight_key":
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"ffn2_expert_weight_key":
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.quant_weight",
"ffn1_expert_weight_scale_key":
"up_gate_proj_expert_weight_scale_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"ffn2_expert_weight_scale_key":
"down_proj_expert_weight_scale_key":
f"{prefix}.experts.{{}}.down_proj.weight_scale",
"ffn1_expert_super_scales_key":
"up_gate_proj_expert_super_scales_key":
f"{prefix}.experts.{{}}.up_gate_proj.super_scales",
"ffn2_expert_super_scales_key":
"down_proj_expert_super_scales_key":
f"{prefix}.experts.{{}}.down_proj.super_scales",
"ffn1_expert_code_scale_key":
"up_gate_proj_expert_code_scale_key":
f"{prefix}.experts.{{}}.up_gate_proj.code_scale",
"ffn2_expert_code_scale_key":
"down_proj_expert_code_scale_key":
f"{prefix}.experts.{{}}.down_proj.code_scale",
"ffn1_expert_code_zp_key":
"up_gate_proj_expert_code_zp_key":
f"{prefix}.experts.{{}}.up_gate_proj.code_zp",
"ffn2_expert_code_zp_key":
"down_proj_expert_code_zp_key":
f"{prefix}.experts.{{}}.down_proj.code_zp",
}
elif moe_quant_type == "tensor_wise_fp8" or (
@@ -152,17 +153,17 @@ class Ernie4_5_MoE(nn.Layer):
f"{prefix}.gate.weight",
"gate_correction_bias_key":
f"{prefix}.moe_statics.e_score_correction_bias",
"ffn1_expert_weight_key":
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"ffn2_expert_weight_key":
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.quant_weight",
"ffn1_expert_weight_scale_key":
"up_gate_proj_expert_weight_scale_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"ffn2_expert_weight_scale_key":
"down_proj_expert_weight_scale_key":
f"{prefix}.experts.{{}}.down_proj.weight_scale",
"ffn1_expert_in_scale_key":
"up_gate_proj_expert_in_scale_key":
f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
"ffn2_expert_in_scale_key":
"down_proj_expert_in_scale_key":
f"{prefix}.experts.{{}}.down_proj.activation_scale",
}
else:
@@ -171,9 +172,9 @@ class Ernie4_5_MoE(nn.Layer):
f"{prefix}.gate.weight",
"gate_correction_bias_key":
f"{prefix}.moe_statics.e_score_correction_bias",
"ffn1_expert_weight_key":
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight",
"ffn2_expert_weight_key":
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.weight",
}
@@ -271,7 +272,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
prefix=f"{prefix}.self_attn",
)
if (fd_config.model_config.moe_num_experts is not None
if (getattr(fd_config.model_config, "moe_num_experts", None) is not None
and layer_id >= fd_config.model_config.moe_layer_start_index):
self.mlp = Ernie4_5_MoE(
fd_config=fd_config,
@@ -349,14 +350,14 @@ class Ernie4_5_Model(nn.Layer):
self.num_layers = fd_config.model_config.num_hidden_layers
fd_config.model_config.pretrained_config.prefix_name = "ernie"
self.embeddings = VocabParallelEmbedding(
self.embed_tokens = VocabParallelEmbedding(
fd_config=fd_config,
num_embeddings=fd_config.model_config.vocab_size,
embedding_dim=fd_config.model_config.hidden_size,
params_dtype=paddle.get_default_dtype(),
prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"))
self.hidden_layers = nn.LayerList([
self.layers = nn.LayerList([
Ernie4_5_DecoderLayer(
fd_config=fd_config,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}")
@@ -379,22 +380,22 @@ class Ernie4_5_Model(nn.Layer):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.embeddings.load_state_dict(state_dict)
self.embed_tokens.load_state_dict(state_dict)
self.norm.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
self.hidden_layers[i].load_state_dict(state_dict)
self.layers[i].load_state_dict(state_dict)
def forward(
self,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.hidden_layers[i](forward_meta,
hidden_states, residual = self.layers[i](forward_meta,
hidden_states,
residual)
@@ -417,7 +418,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
"""
super(Ernie4_5_MoeForCausalLM, self).__init__(fd_config)
self.fd_config = fd_config
self.model = Ernie4_5_Model(fd_config=fd_config)
self.ernie = Ernie4_5_Model(fd_config=fd_config)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
@@ -444,10 +445,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.model.load_state_dict(state_dict)
self.ernie.load_state_dict(state_dict)
if self.tie_word_embeddings:
self.lm_head.out_linear.weight.set_value(
self.model.embeddings.word_embeddings.weight.transpose([1, 0]))
self.lm_head.linear.weight.set_value(
self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
else:
self.lm_head.load_state_dict(state_dict)
@@ -468,14 +469,14 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
)
for i in range(self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers):
self.model.hidden_layers[i].mlp.fused_moe(fake_hidden_states)
self.ernie.layers[i].mlp.fused_moe(fake_hidden_states)
def forward(
self,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
hidden_states = self.model(ids_remove_padding=ids_remove_padding,
hidden_states = self.ernie(ids_remove_padding=ids_remove_padding,
forward_meta=forward_meta)
return hidden_states
@@ -559,7 +560,7 @@ class Ernie4_5_PretrainedModel(PretrainedModel):
]
@classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True):
def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True):
"""
get_tensor_parallel_mappings
"""

View File

@@ -263,9 +263,9 @@ class Ernie4_5_MTPModel(nn.Layer):
super().__init__()
self.num_layers = fd_config.model_config.num_hidden_layers
self.embeddings = fd_config.speculative_config.sharing_model.model.embeddings
self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens
self.hidden_layers = nn.LayerList([
self.layers = nn.LayerList([
Ernie4_5_DecoderLayer(
fd_config=fd_config,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.{i}")
@@ -302,13 +302,13 @@ class Ernie4_5_MTPModel(nn.Layer):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
# self.embeddings.load_state_dict(state_dict)
# self.embed_tokens.load_state_dict(state_dict)
self.enorm.load_state_dict(state_dict)
self.hnorm.load_state_dict(state_dict)
self.eh_proj.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
self.hidden_layers[i].load_state_dict(state_dict)
self.layers[i].load_state_dict(state_dict)
def forward(
self,
@@ -319,7 +319,7 @@ class Ernie4_5_MTPModel(nn.Layer):
"""
forward
"""
inputs_embedding = self.embeddings(
inputs_embedding = self.embed_tokens(
ids_remove_padding=ids_remove_padding)
inputs_embedding = paddle.concat(
[self.enorm(inputs_embedding),
@@ -328,7 +328,7 @@ class Ernie4_5_MTPModel(nn.Layer):
hidden_states = self.eh_proj(inputs_embedding)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.hidden_layers[i](forward_meta,
hidden_states, residual = self.layers[i](forward_meta,
hidden_states,
residual)
@@ -349,7 +349,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
"""
super(Ernie4_5_MTPForCausalLM, self).__init__(fd_config)
self.fd_config = fd_config
self.model = Ernie4_5_MTPModel(fd_config=fd_config)
self.ernie = Ernie4_5_MTPModel(fd_config=fd_config)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
@@ -373,10 +373,10 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.model.load_state_dict(state_dict)
self.ernie.load_state_dict(state_dict)
# if self.tie_word_embeddings:
# self.lm_head.out_linear.weight.set_value(
# self.model.embeddings.word_embeddings.weight.transpose([1, 0]))
# self.lm_head.linear.weight.set_value(
# self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
# else:
# self.lm_head.load_state_dict(state_dict)
@@ -400,7 +400,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
)
for i in range(self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers):
self.model.hidden_layers[i].mlp.fused_moe(fake_hidden_states)
self.ernie.layers[i].mlp.fused_moe(fake_hidden_states)
def forward(
self,
@@ -411,7 +411,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
"""
forward
"""
hidden_states = self.model(ids_remove_padding, previous_hidden_states,
hidden_states = self.ernie(ids_remove_padding, previous_hidden_states,
forward_meta)
return hidden_states

View File

@@ -24,9 +24,10 @@ import numpy as np
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.graph_optimization.decorator import \
@@ -99,12 +100,12 @@ class Ernie4_5_VLMoE(nn.Layer):
f"{prefix}.gate.weight",
"gate_correction_bias_key":
f"{prefix}.moe_statics.e_score_correction_bias",
"ffn1_expert_weight_key":
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight",
"ffn2_expert_weight_key":
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.weight",
}
self.mlp_text = FusedMoE(
self.text_fused_moe = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.model_config.
@@ -116,9 +117,9 @@ class Ernie4_5_VLMoE(nn.Layer):
moe_tag="Text",
weight_key_map=weight_key_map,
)
self.mlp_text.extract_gate_correction_bias = self.extract_gate_correction_bias_text
self.text_fused_moe.extract_gate_correction_bias = self.extract_gate_correction_bias_text
else:
self.mlp_text = Ernie4_5_VLMLP(
self.text_fused_moe = Ernie4_5_VLMLP(
fd_config=fd_config,
intermediate_size=fd_config.model_config.intermediate_size,
prefix=f"{prefix}",
@@ -131,12 +132,12 @@ class Ernie4_5_VLMoE(nn.Layer):
f"{prefix}.gate.weight_1",
"gate_correction_bias_key":
f"{prefix}.moe_statics.e_score_correction_bias",
"ffn1_expert_weight_key":
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight",
"ffn2_expert_weight_key":
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.weight",
}
self.mlp_image = FusedMoE(
self.image_fused_moe = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.model_config.
@@ -148,9 +149,9 @@ class Ernie4_5_VLMoE(nn.Layer):
moe_tag="Image",
weight_key_map=weight_key_map,
)
self.mlp_image.extract_gate_correction_bias = self.extract_gate_correction_bias_image
self.image_fused_moe.extract_gate_correction_bias = self.extract_gate_correction_bias_image
else:
self.mlp_image = Ernie4_5_VLMLP(
self.image_fused_moe = Ernie4_5_VLMLP(
fd_config=fd_config,
intermediate_size=fd_config.model_config.intermediate_size,
prefix=f"{prefix}",
@@ -185,10 +186,10 @@ class Ernie4_5_VLMoE(nn.Layer):
return gate_correction_bias_tensor[1].unsqueeze(0)
def load_state_dict(self, state_dict):
self.mlp_text.load_state_dict(state_dict)
self.mlp_image.load_state_dict(state_dict)
if self.mlp_text.moe_use_gate_correction_bias:
state_dict.pop(self.mlp_text.gate_correction_bias_key)
self.text_fused_moe.load_state_dict(state_dict)
self.image_fused_moe.load_state_dict(state_dict)
if self.text_fused_moe.moe_use_gate_correction_bias:
state_dict.pop(self.text_fused_moe.gate_correction_bias_key)
if self.num_shared_experts > 0:
self.share_experts.load_state_dict(state_dict)
@@ -205,8 +206,8 @@ class Ernie4_5_VLMoE(nn.Layer):
vl_moe_meta.image_index,
True,
)
text_out = self.mlp_text(vl_moe_meta.text_input)
image_out = self.mlp_image(vl_moe_meta.image_input)
text_out = self.text_fused_moe(vl_moe_meta.text_input)
image_out = self.image_fused_moe(vl_moe_meta.image_input)
text_image_gather_scatter(
hidden_states,
text_out,
@@ -217,7 +218,7 @@ class Ernie4_5_VLMoE(nn.Layer):
False,
)
else:
hidden_states = self.mlp_text(hidden_states)
hidden_states = self.text_fused_moe(hidden_states)
if self.num_shared_experts > 0:
hidden_states += share_experts_out
if self.tp_size > 1:
@@ -342,7 +343,7 @@ class Ernie4_5_VLModel(nn.Layer):
self._dtype = fd_config.model_config.dtype
fd_config.model_config.pretrained_config.prefix_name = "ernie"
self.embeddings = VocabParallelEmbedding(
self.embed_tokens = VocabParallelEmbedding(
fd_config=fd_config,
num_embeddings=fd_config.model_config.vocab_size,
embedding_dim=fd_config.model_config.hidden_size,
@@ -350,7 +351,7 @@ class Ernie4_5_VLModel(nn.Layer):
prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"),
)
self.hidden_layers = nn.LayerList([
self.layers = nn.LayerList([
Ernie4_5_VLDecoderLayer(
fd_config=fd_config,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}")
@@ -373,11 +374,11 @@ class Ernie4_5_VLModel(nn.Layer):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.embeddings.load_state_dict(state_dict)
self.embed_tokens.load_state_dict(state_dict)
self.norm.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
self.hidden_layers[i].load_state_dict(state_dict)
self.layers[i].load_state_dict(state_dict)
def forward(
self,
@@ -391,7 +392,7 @@ class Ernie4_5_VLModel(nn.Layer):
image_index = None
image_token_num = 0
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
# -----------------------
image_mask = ids_remove_padding == self.im_patch_id
@@ -424,7 +425,7 @@ class Ernie4_5_VLModel(nn.Layer):
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.hidden_layers[i](
hidden_states, residual = self.layers[i](
forward_meta,
hidden_states,
residual,
@@ -539,8 +540,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
self.vision_model.load_state_dict(state_dict)
self.resampler_model.load_state_dict(state_dict)
if self.tie_word_embeddings:
self.lm_head.out_linear.weight.set_value(
self.ernie.embeddings.word_embeddings.weight.transpose([1, 0]))
self.lm_head.linear.weight.set_value(
self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
else:
self.lm_head.load_state_dict(state_dict)
@@ -666,7 +667,7 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
]
@classmethod
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True):
"""
get_tensor_parallel_mappings
"""
@@ -686,10 +687,10 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.vision_config.num_heads,
num_key_value_heads=config.vision_config.num_heads,
head_dim=config.vision_config.hidden_size
// config.vision_config.num_heads,
num_attention_heads=config.vision_config.get("num_heads"),
num_key_value_heads=config.vision_config.get("num_heads"),
head_dim=config.vision_config.get("hidden_size")
// config.vision_config.get("num_heads"),
)
def get_tensor_parallel_split_mappings(
@@ -754,7 +755,7 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
config.prefix_name,
)
vision_mappings = get_vison_parallel_split_mappings(
config.vision_config.depth
config.vision_config.get("depth")
)
return {**mappings, **vision_mappings}

View File

@@ -48,7 +48,7 @@ class Qwen2MLP(nn.Layer):
) -> None:
super().__init__()
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.gate_up_proj = MergedColumnParallelLinear(
self.up_gate_proj = MergedColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.up_gate_proj",
input_size=fd_config.model_config.hidden_size,
@@ -67,20 +67,20 @@ class Qwen2MLP(nn.Layer):
self.act_fn = SiluAndMul(
fd_config=fd_config,
bias=getattr(self.gate_up_proj, "linear_bias", None),
bias=getattr(self.up_gate_proj, "bias", None),
act_method=fd_config.model_config.hidden_act,
)
def load_state_dict(self, state_dict):
"""
"""
self.gate_up_proj.load_state_dict(state_dict)
self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict)
def forward(self, x):
"""
"""
gate_up_out = self.gate_up_proj(x)
gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out)
return down_out
@@ -230,7 +230,7 @@ class Qwen2Model(nn.Layer):
self.num_layers = fd_config.model_config.num_hidden_layers
fd_config.model_config.pretrained_config.prefix_name = "qwen2"
self.embeddings = VocabParallelEmbedding(
self.embed_tokens = VocabParallelEmbedding(
fd_config=fd_config,
num_embeddings=fd_config.model_config.vocab_size,
embedding_dim=fd_config.model_config.hidden_size,
@@ -261,7 +261,7 @@ class Qwen2Model(nn.Layer):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.embeddings.load_state_dict(state_dict)
self.embed_tokens.load_state_dict(state_dict)
self.norm.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
@@ -275,7 +275,7 @@ class Qwen2Model(nn.Layer):
"""
"""
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
residual = None
@@ -303,7 +303,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
super(Qwen2ForCausalLM, self).__init__(fd_config)
self.fd_config =fd_config
self.model = Qwen2Model(fd_config=fd_config)
self.qwen2 = Qwen2Model(fd_config=fd_config)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
@@ -330,7 +330,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.model.load_state_dict(state_dict)
self.qwen2.load_state_dict(state_dict)
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
@@ -349,7 +349,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
):
"""
"""
hidden_states = self.model(ids_remove_padding=ids_remove_padding,
hidden_states = self.qwen2(ids_remove_padding=ids_remove_padding,
forward_meta=forward_meta)
return hidden_states

View File

@@ -166,7 +166,7 @@ class Qwen3Model(nn.Layer):
self.num_layers = fd_config.model_config.num_hidden_layers
fd_config.model_config.pretrained_config.prefix_name = "model"
self.embeddings = VocabParallelEmbedding(
self.embed_tokens = VocabParallelEmbedding(
fd_config=fd_config,
num_embeddings=fd_config.model_config.vocab_size,
embedding_dim=fd_config.model_config.hidden_size,
@@ -197,7 +197,7 @@ class Qwen3Model(nn.Layer):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.embeddings.load_state_dict(state_dict)
self.embed_tokens.load_state_dict(state_dict)
self.norm.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
@@ -210,7 +210,7 @@ class Qwen3Model(nn.Layer):
):
"""
"""
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
residual = None
@@ -266,8 +266,8 @@ class Qwen3ForCausalLM(ModelForCasualLM):
"""
self.model.load_state_dict(state_dict)
if self.tie_word_embeddings:
self.lm_head.out_linear.weight.set_value(
self.model.embeddings.word_embeddings.weight.transpose([1, 0]))
self.lm_head.linear.weight.set_value(
self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
else:
self.lm_head.load_state_dict(state_dict)

View File

@@ -50,7 +50,7 @@ class Qwen3MLP(nn.Layer):
super().__init__()
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.gate_up_proj = MergedColumnParallelLinear(
self.up_gate_proj = MergedColumnParallelLinear(
fd_config,
prefix=f"{prefix}.up_gate_proj",
input_size=fd_config.model_config.hidden_size,
@@ -69,20 +69,20 @@ class Qwen3MLP(nn.Layer):
self.act_fn = SiluAndMul(
fd_config,
bias=getattr(self.gate_up_proj, "linear_bias", None),
bias=getattr(self.up_gate_proj, "bias", None),
act_method=fd_config.model_config.hidden_act,
)
def load_state_dict(self, state_dict):
"""
"""
self.gate_up_proj.load_state_dict(state_dict)
self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict)
def forward(self, x):
"""
"""
gate_up_out = self.gate_up_proj(x)
gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out)
return down_out
@@ -108,9 +108,9 @@ class Qwen3DecoderLayer(nn.Layer):
weight_key_map = {
"gate_weight_key":
f"{prefix}.mlp.gate.weight",
"ffn1_expert_weight_key":
"up_gate_proj_expert_weight_key":
f"{prefix}.mlp.experts.{{}}.up_gate_proj.weight",
"ffn2_expert_weight_key":
"down_proj_expert_weight_key":
f"{prefix}.mlp.experts.{{}}.down_proj.weight",
}
@@ -201,7 +201,7 @@ class Qwen3MoeModel(nn.Layer):
self.num_layers = fd_config.model_config.num_hidden_layers
fd_config.model_config.pretrained_config.prefix_name = "model"
self.embeddings = VocabParallelEmbedding(
self.embed_tokens = VocabParallelEmbedding(
fd_config,
num_embeddings=fd_config.model_config.vocab_size,
embedding_dim=fd_config.model_config.hidden_size,
@@ -232,7 +232,7 @@ class Qwen3MoeModel(nn.Layer):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
self.embeddings.load_state_dict(state_dict)
self.embed_tokens.load_state_dict(state_dict)
self.norm.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
@@ -245,7 +245,7 @@ class Qwen3MoeModel(nn.Layer):
):
"""
"""
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
residual = None

View File

@@ -311,18 +311,18 @@ def w4a8_weight_convert(state_dict):
w4a8_weight_bites_layers_map = {}
w4a8_weight_bites_layers_map["qkv_gemm_bits_map"] = []
w4a8_weight_bites_layers_map["out_gemm_bits_map"] = []
w4a8_weight_bites_layers_map["ffn1_gemm_bits_map"] = []
w4a8_weight_bites_layers_map["ffn2_gemm_bits_map"] = []
w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"] = []
w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"] = []
for name_keys, gemm_bits in w4a8_weight_bites_name_map.items():
if "qkv_proj" in name_keys:
w4a8_weight_bites_layers_map["qkv_gemm_bits_map"].append(gemm_bits)
elif "out_proj" in name_keys:
w4a8_weight_bites_layers_map["out_gemm_bits_map"].append(gemm_bits)
elif "linear1" in name_keys:
w4a8_weight_bites_layers_map["ffn1_gemm_bits_map"].append(
w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"].append(
gemm_bits)
elif "linear2" in name_keys:
w4a8_weight_bites_layers_map["ffn2_gemm_bits_map"].append(
w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"].append(
gemm_bits)
logger.debug(
f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}")

View File

@@ -15,9 +15,10 @@
"""
from typing import Optional
import paddle
from paddle.nn.quant import weight_only_linear
from paddle.incubate.nn.functional import swiglu
from paddle.nn.quant import weight_only_linear
def group_gemm(
@@ -71,31 +72,31 @@ def group_gemm(
def iluvatar_moe_expert_ffn(
permute_input: paddle.Tensor,
tokens_expert_prefix_sum: paddle.Tensor,
ffn1_weight: paddle.Tensor,
ffn2_weight: paddle.Tensor,
ffn1_bias: Optional[paddle.Tensor],
ffn1_scale: Optional[paddle.Tensor],
ffn2_scale: Optional[paddle.Tensor],
ffn2_in_scale: Optional[paddle.Tensor],
up_gate_proj_weight: paddle.Tensor,
down_proj_weight: paddle.Tensor,
up_gate_proj_bias: Optional[paddle.Tensor],
up_gate_proj_scale: Optional[paddle.Tensor],
down_proj_scale: Optional[paddle.Tensor],
down_proj_in_scale: Optional[paddle.Tensor],
expert_idx_per_token: Optional[paddle.Tensor],
quant_method: str,
used_in_ep_low_latency: bool,
):
assert ffn1_bias is None
assert ffn1_scale is not None
assert ffn2_scale is not None
assert ffn2_in_scale is None
assert up_gate_proj_bias is None
assert up_gate_proj_scale is not None
assert down_proj_scale is not None
assert down_proj_in_scale is None
assert expert_idx_per_token is None
assert quant_method in ("weight_only_int8")
assert not used_in_ep_low_latency
tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
ffn1_output = paddle.empty([permute_input.shape[0], ffn1_weight.shape[1]],
up_gate_proj_output = paddle.empty([permute_input.shape[0], up_gate_proj_weight.shape[1]],
dtype=permute_input.dtype)
group_gemm(permute_input, tokens_expert_prefix_sum_cpu, ffn1_weight,
ffn1_scale, ffn1_output)
act_out = swiglu(ffn1_output)
output = paddle.empty([act_out.shape[0], ffn2_weight.shape[1]],
group_gemm(permute_input, tokens_expert_prefix_sum_cpu, up_gate_proj_weight,
up_gate_proj_scale, up_gate_proj_output)
act_out = swiglu(up_gate_proj_output)
output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]],
dtype=act_out.dtype)
group_gemm(act_out, tokens_expert_prefix_sum_cpu, ffn2_weight, ffn2_scale,
group_gemm(act_out, tokens_expert_prefix_sum_cpu, down_proj_weight, down_proj_scale,
output)
return output

View File

@@ -386,19 +386,19 @@ def moe_wint2_ffn_kernel(
def fused_moe_wint2_impl(
hidden_states,
ffn1_quant_weight,
ffn2_quant_weight,
up_gate_proj_quant_weight,
down_proj_quant_weight,
topk_weights,
topk_ids,
# inplace: bool = False,
ffn1_weight_scale=None,
ffn2_weight_scale=None,
ffn1_super_scales=None,
ffn2_super_scales=None,
ffn1_code_scale=None,
ffn2_code_scale=None,
ffn1_code_zp=None,
ffn2_code_zp=None,
up_gate_proj_weight_scale=None,
down_proj_weight_scale=None,
up_gate_proj_super_scales=None,
down_proj_super_scales=None,
up_gate_proj_code_scale=None,
down_proj_code_scale=None,
up_gate_proj_code_zp=None,
down_proj_code_zp=None,
group_size=64,
bit="wint2",
):
@@ -408,22 +408,22 @@ def fused_moe_wint2_impl(
# Check constraints.
# A: [M, K]
# B: [E, K, N]
# assert hidden_states.shape[1] == ffn1_weight_scale.shape[1],
# f"Hidden size mismatch, {hidden_states.shape[1]} != {ffn1_quant_weight.shape[1]}"
# assert hidden_states.shape[1] == up_gate_proj_weight_scale.shape[1],
# f"Hidden size mismatch, {hidden_states.shape[1]} != {up_gate_proj_quant_weight.shape[1]}"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert ffn1_quant_weight.is_contiguous(
assert up_gate_proj_quant_weight.is_contiguous(
), "Expert weights1 must be contiguous"
assert ffn2_quant_weight.is_contiguous(
assert down_proj_quant_weight.is_contiguous(
), "Expert weights2 must be contiguous"
assert group_size > 0, "Group size must be greater than 0"
num_tokens, K = hidden_states.shape
E, _, N = ffn1_quant_weight.shape
E, _, N = up_gate_proj_quant_weight.shape
M = num_tokens
if group_size < 0:
group_size = K // ffn1_weight_scale.shape[1]
group_size = K // up_gate_proj_weight_scale.shape[1]
top_k = topk_ids.shape[1]
@@ -448,12 +448,12 @@ def fused_moe_wint2_impl(
invoke_fused_moe_kernel(
A=hidden_states,
B=ffn1_quant_weight,
B=up_gate_proj_quant_weight,
C=intermediate_cache1,
B_scale=ffn1_weight_scale,
B_super_scale=ffn1_super_scales,
B_code_scale=ffn1_code_scale,
B_code_zp=ffn1_code_zp,
B_scale=up_gate_proj_weight_scale,
B_super_scale=up_gate_proj_super_scales,
B_code_scale=up_gate_proj_code_scale,
B_code_zp=up_gate_proj_code_zp,
topk_weights=topk_weights,
topk_ids=topk_ids,
sorted_token_ids=sorted_token_ids,
@@ -469,12 +469,12 @@ def fused_moe_wint2_impl(
invoke_fused_moe_kernel(
A=intermediate_cache2,
B=ffn2_quant_weight,
B=down_proj_quant_weight,
C=intermediate_cache3,
B_scale=ffn2_weight_scale,
B_super_scale=ffn2_super_scales,
B_code_scale=ffn2_code_scale,
B_code_zp=ffn2_code_zp,
B_scale=down_proj_weight_scale,
B_super_scale=down_proj_super_scales,
B_code_scale=down_proj_code_scale,
B_code_zp=down_proj_code_zp,
topk_weights=topk_weights,
topk_ids=topk_ids,
sorted_token_ids=sorted_token_ids,
@@ -491,37 +491,37 @@ def fused_moe_wint2_impl(
def fused_moe_wint2_triton(
hidden_states,
ffn1_quant_weight,
ffn2_quant_weight,
up_gate_proj_quant_weight,
down_proj_quant_weight,
scores,
gate_correction_bias,
topk,
ffn1_weight_scale,
ffn2_weight_scale,
ffn1_super_scales,
ffn2_super_scales,
ffn1_code_scale,
ffn2_code_scale,
ffn1_code_zp,
ffn2_code_zp,
up_gate_proj_weight_scale,
down_proj_weight_scale,
up_gate_proj_super_scales,
down_proj_super_scales,
up_gate_proj_code_scale,
down_proj_code_scale,
up_gate_proj_code_zp,
down_proj_code_zp,
):
"""
Fuse MoE with WINT2 quantization scheme and Triton backend.
Args:
hidden_states: input tensor.
ffn1_quant_weight: ffn1 weight matrix for experts.
ffn2_quant_weight: ffn2 weight matrix for experts.
up_gate_proj_quant_weight: up_gate_proj weight matrix for experts.
down_proj_quant_weight: down_proj weight matrix for experts.
scores: gate scores.
gate_correction_bias: bias correction for gates.
topk: number of experts to use.
ffn1_weight_scale: scaling factor for ffn1_quant_weight.
ffn2_weight_scale: scaling factor for ffn2_quant_weight.
ffn1_super_scales: super scaling factor for ffn1_scale.
ffn2_super_scales: super scaling factor for ffn2_weight_scale.
ffn1_code_scale: code scaling factor for ffn1_quant_weight.
ffn2_code_scale: code scaling factor for ffn2_quant_weight.
ffn1_code_zp: code zero point for ffn1_quant_weight.
ffn2_code_zp: code zero point for ffn2_quant_weight.
up_gate_proj_weight_scale: scaling factor for up_gate_proj_quant_weight.
down_proj_weight_scale: scaling factor for down_proj_quant_weight.
up_gate_proj_super_scales: super scaling factor for up_gate_proj_scale.
down_proj_super_scales: super scaling factor for down_proj_weight_scale.
up_gate_proj_code_scale: code scaling factor for up_gate_proj_quant_weight.
down_proj_code_scale: code scaling factor for down_proj_quant_weight.
up_gate_proj_code_zp: code zero point for up_gate_proj_quant_weight.
down_proj_code_zp: code zero point for down_proj_quant_weight.
Returns:
output tensor.
"""
@@ -533,17 +533,17 @@ def fused_moe_wint2_triton(
return fused_moe_wint2_impl(
hidden_states,
ffn1_quant_weight,
ffn2_quant_weight,
up_gate_proj_quant_weight,
down_proj_quant_weight,
topk_weights,
topk_ids,
ffn1_weight_scale,
ffn2_weight_scale,
ffn1_super_scales,
ffn2_super_scales,
ffn1_code_scale,
ffn2_code_scale,
ffn1_code_zp,
ffn2_code_zp,
up_gate_proj_weight_scale,
down_proj_weight_scale,
up_gate_proj_super_scales,
down_proj_super_scales,
up_gate_proj_code_scale,
down_proj_code_scale,
up_gate_proj_code_zp,
down_proj_code_zp,
bit="wint2",
)

View File

@@ -24,11 +24,14 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.model_loader import ModelRegistry
from fastdeploy.model_executor.models.ernie4_5_moe import \
Ernie4_5_MoeForCausalLM
from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import \
Ernie4_5_VLMoeForConditionalGeneration
from fastdeploy.model_executor.models.qwen2 import Qwen2ForCausalLM
from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoeForCausalLM
from fastdeploy.rl.rollout_config import RolloutModelConfig
class RolloutModel(nn.Layer):
"""Main model class for rollout operations, supports multimodal components for train."""
@@ -36,55 +39,26 @@ class RolloutModel(nn.Layer):
"""Initialize with FastDeploy configuration."""
super(RolloutModel, self).__init__()
self.fd_config = rollout_model_config.initialize()
self._init_models()
def _init_models(self):
"""Initialize all model components including multimodal if needed."""
self.is_vl = "VL" in self.fd_config.model_config.architectures[0]
self.rollout_model = self._load_primary_model()
self.rollout_models = [self.rollout_model]
if self.is_vl:
self._init_multimodal_models()
self.rollout_models.extend(
[self.vision_model, self.resampler_model])
def _init_multimodal_models(self):
"""Initialize vision and resampler components for multimodal models."""
# TODO:(gaoziyuan) Implement actual initialization
self.vision_model = nn.Layer()
self.resampler_model = nn.Layer()
def _load_primary_model(self):
"""Load main model from loader based on config."""
if "VL" in self.fd_config.model_config.architectures[0]:
logger.error("Loaded Vision Language model, not support now")
self._init_model()
def _init_model(self):
"""Load model from loader based on config."""
context = paddle.LazyGuard()
architectures = f"{self.fd_config.model_config.architectures[0]}RL"
with context:
model_cls = ModelRegistry.get_class(architectures)
model = model_cls(self.fd_config)
model.eval()
return model
self.rollout_model = model.eval()
def get_name_mappings_to_training(self) -> Dict[str, str]:
"""Get parameter name mappings between rollout and training models."""
mappings = {}
for model in self.rollout_models:
mappings.update(
getattr(model, "get_name_mappings_to_training", lambda: {})())
return mappings
return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})()
@paddle.no_grad()
def state_dict(self):
"""state_dict"""
all_params = {}
for model in self.rollout_models:
for name, param in model.state_dict().items():
all_params[name] = param
return all_params
return self.rollout_model.state_dict()
class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
@@ -113,98 +87,159 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
# Initialize mapping dictionary
infer_to_train = {}
infer_base_name = "model"
train_base_name = "ernie"
base_name = "ernie"
# Static mappings (non-layer specific)
static_mappings = {
f"{infer_base_name}.embeddings.word_embeddings.weight":
f"{train_base_name}.embed_tokens.weight",
f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
"lm_head.out_linear.weight": "lm_head.weight"
f"{base_name}.embed_tokens.embeddings.weight":
f"{base_name}.embed_tokens.weight",
"lm_head.linear.weight": "lm_head.weight"
}
if self.fd_config.model_config.get("weight_sharing", False):
if self.fd_config.model_config.get("tie_word_embeddings", False):
# Support tie_word_embeddings
logger.debug("enable tie_word_embeddings")
static_mappings.pop("lm_head.out_linear.weight")
static_mappings.pop("lm_head.linear.weight")
infer_to_train.update(static_mappings)
infer_base_name = infer_base_name + ".hidden_layers"
train_base_name = train_base_name + ".layers"
base_name = base_name + ".layers"
# Helper function to add layer mappings
def _add_layer_mappings(layer_idx, is_moe_layer=False):
# Handle special case for layer 0's input layernorm
for ph in place_holders:
infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
infer_to_train[infer_key] = train_key
# Common attention mappings
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"
# Post-attention layernorm
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"
if not is_moe_layer:
# Dense FFN mappings
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
else:
def _add_layer_mappings(layer_idx: int):
# MoE specific mappings
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
f"{train_base_name}.{layer_idx}.mlp.gate.weight"
infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
f"{base_name}.{layer_idx}.mlp.gate.weight"
if self.fd_config.model_config.moe_use_aux_free:
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
f"{train_base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
# Support shared experts
if self.fd_config.model_config.get(
"moe_num_shared_experts") > 0:
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \
f"{train_base_name}.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \
f"{train_base_name}.{layer_idx}.mlp.shared_experts.down_proj.weight"
infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
# MoE experts mappings
for expert_idx in range(self.fd_config.model_config.moe_num_experts):
for ph in place_holders:
# FFN1 (up_gate_proj)
ffn1_key = f"{infer_base_name}.{layer_idx}.mlp.fused_moe.moe_ffn1_weight"
if ffn1_key not in infer_to_train:
infer_to_train[ffn1_key] = []
infer_to_train[ffn1_key].append(
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
# up_gate_proj (up_gate_proj)
up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.up_gate_proj_weight"
if up_gate_proj_key not in infer_to_train:
infer_to_train[up_gate_proj_key] = []
infer_to_train[up_gate_proj_key].append(
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
)
# FFN2 (down_proj)
ffn2_key = f"{infer_base_name}.{layer_idx}.mlp.fused_moe.moe_ffn2_weight"
if ffn2_key not in infer_to_train:
infer_to_train[ffn2_key] = []
infer_to_train[ffn2_key].append(
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
# down_proj (down_proj)
down_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.down_proj_weight"
if down_proj_key not in infer_to_train:
infer_to_train[down_proj_key] = []
infer_to_train[down_proj_key].append(
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
)
# Process non-MoE layers
for layer_idx in range(
self.fd_config.model_config.moe_layer_start_index):
_add_layer_mappings(layer_idx, is_moe_layer=False)
assert isinstance(self.fd_config.model_config.moe_layer_start_index, int)
# Process MoE layers
for layer_idx in range(self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers):
_add_layer_mappings(layer_idx, is_moe_layer=True)
_add_layer_mappings(layer_idx)
return infer_to_train
class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGeneration):
"""
Ernie4_5_VLMoeForConditionalGenerationRL
"""
def __init__(self, fd_config: FDConfig):
"""
Args:
fd_config (FDConfig): Configurations for the LLM model.
"""
super(Ernie4_5_VLMoeForConditionalGenerationRL, self).__init__(fd_config)
@classmethod
def name(self):
"""name"""
return "Ernie4_5_VLMoeForConditionalGenerationRL"
def get_name_mappings_to_training(self):
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
have_bias = self.fd_config.model_config.get("have_norm_bias", False)
# Prepare placeholders
place_holders = ["weight"] + (["bias"] if have_bias else [])
# Initialize mapping dictionary
infer_to_train = {}
base_name = "ernie"
# Static mappings (non-layer specific)
static_mappings = {
f"{base_name}.embed_tokens.embeddings.weight":
f"{base_name}.embed_tokens.weight",
"lm_head.linear.weight": "lm_head.weight"
}
if self.fd_config.model_config.get("tie_word_embeddings", False):
# Support tie_word_embeddings
logger.debug("enable tie_word_embeddings")
static_mappings.pop("lm_head.linear.weight")
infer_to_train.update(static_mappings)
base_name = base_name + ".layers"
# Helper function to add layer mappings
def _add_layer_mappings(layer_idx: int, moe_tag: str):
# MoE specific mappings
infer_to_train[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_weight"] = f"{base_name}.{layer_idx}.mlp.gate.weight" if moe_tag == "text" else f"{base_name}.{layer_idx}.mlp.gate.weight_1"
if self.fd_config.model_config.moe_use_aux_free:
infer_to_train[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_correction_bias"] = \
f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
# MoE experts mappings
assert isinstance(self.fd_config.model_config.moe_num_experts, list)
if moe_tag == "text":
expert_idx_start = 0
expert_idx_end = self.fd_config.model_config.moe_num_experts[0]
else:
expert_idx_start = self.fd_config.model_config.moe_num_experts[0]
expert_idx_end = self.fd_config.model_config.moe_num_experts[1]
for expert_idx in range(expert_idx_start, expert_idx_end):
for ph in place_holders:
# up_gate_proj (up_gate_proj)
up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"
if up_gate_proj_key not in infer_to_train:
infer_to_train[up_gate_proj_key] = []
infer_to_train[up_gate_proj_key].append(
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
)
# down_proj (down_proj)
down_proj_key = f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.down_proj_weight"
if down_proj_key not in infer_to_train:
infer_to_train[down_proj_key] = []
infer_to_train[down_proj_key].append(
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
)
moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index
if isinstance(moe_layer_start_index, int):
text_moe_layer_start_index = moe_layer_start_index
image_moe_layer_start_index = moe_layer_start_index
else:
text_moe_layer_start_index = moe_layer_start_index[0]
image_moe_layer_start_index = moe_layer_start_index[1]
moe_layer_end_index = self.fd_config.model_config.moe_layer_end_index
if moe_layer_end_index is None:
text_moe_layer_end_index = self.fd_config.model_config.num_hidden_layers
image_moe_layer_end_index = self.fd_config.model_config.num_hidden_layers
elif isinstance(moe_layer_end_index, int):
text_moe_layer_end_index = moe_layer_end_index
image_moe_layer_end_index = moe_layer_end_index
else:
text_moe_layer_end_index = moe_layer_end_index[0]
image_moe_layer_end_index = moe_layer_end_index[1]
# Process MoE layers
for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
_add_layer_mappings(layer_idx, "text")
for layer_idx in range(image_moe_layer_start_index, image_moe_layer_end_index):
_add_layer_mappings(layer_idx, "image")
return infer_to_train
@@ -234,48 +269,23 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
# Initialize mapping dictionary
infer_to_train = {}
infer_base_name = "model"
train_base_name = "qwen2"
base_name = "qwen2"
# Static mappings (non-layer specific)
static_mappings = {
f"{infer_base_name}.embeddings.word_embeddings.weight":
f"{train_base_name}.embed_tokens.weight",
f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
"lm_head.out_linear.weight": "lm_head.weight"
f"{base_name}.embed_tokens.embeddings.weight":
f"{base_name}.embed_tokens.weight",
"lm_head.linear.weight": "lm_head.weight"
}
infer_to_train.update(static_mappings)
infer_base_name = infer_base_name + ".layers"
train_base_name = train_base_name + ".layers"
base_name = base_name + ".layers"
# Helper function to add layer mappings
def _add_layer_mappings(layer_idx):
# Handle special case for layer 0's input layernorm and attn o_proj
for ph in place_holders:
infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
infer_to_train[infer_key] = train_key
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"
# qwen qkv proj need bias
for ph in ["weight", "bias"]:
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"
# Post-attention layernorm
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"
# FFN mappings
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
infer_to_train[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = \
f"{base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
for layer_idx in range(
self.fd_config.model_config.num_hidden_layers):
@@ -309,95 +319,49 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
# Initialize mapping dictionary
infer_to_train = {}
infer_base_name = "model"
train_base_name = "model"
base_name = "model"
# Static mappings (non-layer specific)
static_mappings = {
f"{infer_base_name}.embeddings.word_embeddings.weight":
f"{train_base_name}.embed_tokens.weight",
f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
"lm_head.out_linear.weight": "lm_head.weight"
f"{base_name}.embed_tokens.embeddings.weight":
f"{base_name}.embed_tokens.weight",
"lm_head.linear.weight": "lm_head.weight"
}
infer_to_train.update(static_mappings)
infer_base_name = infer_base_name + ".layers"
train_base_name = train_base_name + ".layers"
base_name = base_name + ".layers"
# Helper function to add layer mappings
def _add_layer_mappings(layer_idx, is_moe_layer=False):
# Handle special case for layer 0's input layernorm and attn o_proj
for ph in place_holders:
infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
infer_to_train[infer_key] = train_key
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"
# qwen q_norm/k_norm
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.q_norm.ln_{ph}"] = \
f"{train_base_name}.{layer_idx}.self_attn.q_norm.{ph}"
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.k_norm.ln_{ph}"] = \
f"{train_base_name}.{layer_idx}.self_attn.k_norm.{ph}"
# qwen qkv proj
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"
# Post-attention layernorm
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"
if not is_moe_layer:
# FFN mappings
for ph in place_holders:
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
else:
def _add_layer_mappings(layer_idx: int):
# MoE specific mappings
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_weight"] = \
f"{train_base_name}.{layer_idx}.mlp.gate.weight"
infer_to_train[f"{base_name}.{layer_idx}.mlp.gate_weight"] = \
f"{base_name}.{layer_idx}.mlp.gate.weight"
if self.fd_config.moe_config.moe_use_aux_free:
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
f"{train_base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
# Support shared experts
if self.fd_config.model_config.get(
"moe_num_shared_experts", 0) > 0:
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \
f"{train_base_name}.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \
f"{train_base_name}.{layer_idx}.mlp.shared_experts.down_proj.weight"
infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
# MoE experts mappings
for expert_idx in range(self.fd_config.moe_config.num_experts):
for ph in place_holders:
# FFN1 (up_gate_proj)
ffn1_key = f"{infer_base_name}.{layer_idx}.mlp.moe_ffn1_weight"
if ffn1_key not in infer_to_train:
infer_to_train[ffn1_key] = []
infer_to_train[ffn1_key].append(
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
# up_gate_proj (up_gate_proj)
up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.up_gate_proj_weight"
if up_gate_proj_key not in infer_to_train:
infer_to_train[up_gate_proj_key] = []
infer_to_train[up_gate_proj_key].append(
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
)
# FFN2 (down_proj)
ffn2_key = f"{infer_base_name}.{layer_idx}.mlp.moe_ffn2_weight"
if ffn2_key not in infer_to_train:
infer_to_train[ffn2_key] = []
infer_to_train[ffn2_key].append(
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
# down_proj (down_proj)
down_proj_key = f"{base_name}.{layer_idx}.mlp.down_proj_weight"
if down_proj_key not in infer_to_train:
infer_to_train[down_proj_key] = []
infer_to_train[down_proj_key].append(
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
)
# Process MoE layers
for layer_idx in range(self.fd_config.model_config.num_hidden_layers):
_add_layer_mappings(layer_idx, is_moe_layer=True)
_add_layer_mappings(layer_idx)
return infer_to_train