Move create_parameters to __init__ in FusedMoE for CutlassBackend and TritonBackend (#3148)

* w4a8 bug

* fix w4a8 bug

* remove code

* modify the triton backend

* fix ep

* fix the bug with tensor_wise_fp8 in triton backend

* fix the RL

* fix bug introduced by merge

* fix the bug in w4a8

* fix the tensor_wise_fp8 bug

* fix RL
Author: Zero Rains
Date: 2025-08-08 15:55:47 +08:00
Committed by: GitHub
Parent: d0e9a70380
Commit: ce1f353c70
10 changed files with 444 additions and 83 deletions
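The core pattern in this change: the MoE layer now allocates parameters such as `gate_correction_bias` once in `__init__` and only copies values into them at load time, instead of creating them inside the loading code. A minimal, self-contained sketch of that pattern (the class and method names below are hypothetical, not the FastDeploy API):

```python
import paddle
from paddle import nn


class TinyGateBias(nn.Layer):
    """Toy layer showing "allocate in __init__, fill at load time"."""

    def __init__(self, num_experts: int):
        super().__init__()
        # Allocated up front, mirroring gate_correction_bias in the diff below.
        self.gate_correction_bias = self.create_parameter(shape=[1, num_experts], dtype="float32")

    def load_from(self, tensor: paddle.Tensor) -> None:
        # Loading only copies values into the existing parameter.
        self.gate_correction_bias.set_value(tensor)


layer = TinyGateBias(num_experts=8)
layer.load_from(paddle.zeros([1, 8], dtype="float32"))
```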


@@ -22,8 +22,14 @@ from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger

# TODO(lulinjun): remove this import after supporting all backends
is_supported_moe_backend = None
if current_platform.is_cuda():
    from .check_backend_supported import is_supported_moe_backend


def get_moe_method():
    """
@@ -121,10 +127,7 @@ class FusedMoE(nn.Layer):
            self.quant_method = moe_quant_config.get_quant_method(self)
            self.moe_quant_type = moe_quant_config.name()
        else:
            # w_fp16 a_fp16
            self.quant_method = get_moe_method()
            self.quant_method.create_weights(self, weight_loader=self.weight_loader)

        self.redundant_table_manger = None
        if self.ep_size > 1:
            if fd_config.model_config.enable_redundant_experts is True:
@@ -139,6 +142,20 @@ class FusedMoE(nn.Layer):
        if fd_config.load_config.dynamic_load_weight:
            # It's for RL to build model
            self.init_moe_weights()
        else:
            self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
            if self.gate_correction_bias_key is not None:
                self.gate_correction_bias = self.create_parameter(shape=[1, self.num_experts], dtype="float32")
            if moe_quant_config:
                if (
                    moe_quant_config
                    and is_supported_moe_backend is not None
                    and is_supported_moe_backend(self.quant_method)
                ):
                    self.quant_method.create_weights(self, weight_loader=self.weight_loader)
            else:
                # w_fp16 a_fp16
                self.quant_method.create_weights(self, weight_loader=self.weight_loader)

        logger.info(
            f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset + self.num_local_experts}), \
@@ -475,23 +492,33 @@ class FusedMoE(nn.Layer):
            gate_correction_bias_tensor = self.extract_gate_correction_bias(
                self.gate_correction_bias_key, state_dict
            )
            self.gate_correction_bias = self.create_parameter(
                shape=gate_correction_bias_tensor.shape,
                dtype="float32",
            )
            self.gate_correction_bias.set_value(gate_correction_bias_tensor)
        else:
            self.gate_correction_bias = None
        if self.fd_config.model_config.is_quantized:
            if getattr(self.fd_config.quant_config, "is_permuted", True):
                self.quant_method.process_prequanted_weights(self, state_dict)
            else:
                self.quant_method.create_weights(self, state_dict)
        else:
            if self.moe_quant_config:
                self.quant_method.create_weights(self, state_dict)

            self.gate_correction_bias = None
        if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method):
            if self.fd_config.model_config.is_quantized:
                if getattr(self.fd_config.quant_config, "is_permuted", True):
                    self.quant_method.process_prequanted_weights(self, state_dict)
                else:
                    self.quant_method.process_loaded_weights(self, state_dict)
            else:
                # w_fp16 a_fp16
                self.quant_method.process_loaded_weights(self, state_dict)
        else:
            if self.fd_config.model_config.is_quantized:
                if getattr(self.fd_config.quant_config, "is_permuted", True):
                    self.quant_method.process_prequanted_weights(self, state_dict)
                else:
                    self.quant_method.create_weights(self, state_dict)
            else:
                if self.moe_quant_config:
                    self.quant_method.create_weights(self, state_dict)
                else:
                    # w_fp16 a_fp16
                    self.quant_method.process_loaded_weights(self, state_dict)

    def forward(self, x: paddle.Tensor, gate: nn.Layer):
        """