Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-11-03 11:02:01 +08:00
Move create_parameters to __init__ in FusedMoE for CutlassBackend and TritonBackend (#3148)
* w4a8 bug
* fix w4a8 bug
* remove code
* modify the triton backend
* fix ep
* fix the bug with tensor_wise_fp8 in triton backend
* fix the RL
* fix bug by merge
* fix the bug in w4a8
* fix the tensor_wise_fp8 bug
* fix RL
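A minimal sketch of the pattern this change moves toward, assuming Paddle is installed (TinyMoELayer and its member names are illustrative stand-ins, not the FastDeploy classes): parameters are allocated once in __init__, and loading a checkpoint only copies values into the pre-allocated tensors.

    import paddle
    from paddle import nn


    class TinyMoELayer(nn.Layer):
        """Illustrative stand-in for a fused-MoE layer (not the FastDeploy class)."""

        def __init__(self, num_experts: int, hidden_size: int):
            super().__init__()
            # Allocate every parameter up front so shapes and dtypes are fixed
            # before any checkpoint is read.
            self.gate_correction_bias = self.create_parameter(shape=[1, num_experts], dtype="float32")
            self.up_gate_proj = self.create_parameter(shape=[num_experts, hidden_size, hidden_size * 2], dtype="float32")

        def load_weights(self, state_dict: dict) -> None:
            # Weight loading only fills values into the existing parameters.
            self.gate_correction_bias.set_value(state_dict["gate_correction_bias"])
            self.up_gate_proj.set_value(state_dict["up_gate_proj"])


    layer = TinyMoELayer(num_experts=4, hidden_size=8)
    layer.load_weights(
        {
            "gate_correction_bias": paddle.zeros([1, 4], dtype="float32"),
            "up_gate_proj": paddle.ones([4, 8, 16], dtype="float32"),
        }
    )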
@@ -22,8 +22,14 @@ from paddleformers.utils.log import logger

from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger

# TODO(lulinjun): remove this import after supporting all backends
is_supported_moe_backend = None
if current_platform.is_cuda():
    from .check_backend_supported import is_supported_moe_backend


def get_moe_method():
    """
@@ -121,10 +127,7 @@ class FusedMoE(nn.Layer):
            self.quant_method = moe_quant_config.get_quant_method(self)
            self.moe_quant_type = moe_quant_config.name()
        else:
            # w_fp16 a_fp16
            self.quant_method = get_moe_method()
            self.quant_method.create_weights(self, weight_loader=self.weight_loader)

        self.redundant_table_manger = None
        if self.ep_size > 1:
            if fd_config.model_config.enable_redundant_experts is True:
@@ -139,6 +142,20 @@ class FusedMoE(nn.Layer):
        if fd_config.load_config.dynamic_load_weight:
            # It's for RL to build model
            self.init_moe_weights()
        else:
            self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
            if self.gate_correction_bias_key is not None:
                self.gate_correction_bias = self.create_parameter(shape=[1, self.num_experts], dtype="float32")
            if moe_quant_config:
                if (
                    moe_quant_config
                    and is_supported_moe_backend is not None
                    and is_supported_moe_backend(self.quant_method)
                ):
                    self.quant_method.create_weights(self, weight_loader=self.weight_loader)
            else:
                # w_fp16 a_fp16
                self.quant_method.create_weights(self, weight_loader=self.weight_loader)

        logger.info(
            f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset + self.num_local_experts}), \
@@ -475,23 +492,33 @@ class FusedMoE(nn.Layer):
            gate_correction_bias_tensor = self.extract_gate_correction_bias(
                self.gate_correction_bias_key, state_dict
            )
            self.gate_correction_bias = self.create_parameter(
                shape=gate_correction_bias_tensor.shape,
                dtype="float32",
            )
            self.gate_correction_bias.set_value(gate_correction_bias_tensor)
        else:
            self.gate_correction_bias = None

        if self.fd_config.model_config.is_quantized:
            if getattr(self.fd_config.quant_config, "is_permuted", True):
                self.quant_method.process_prequanted_weights(self, state_dict)
            else:
                self.quant_method.create_weights(self, state_dict)
        else:
            if self.moe_quant_config:
                self.quant_method.create_weights(self, state_dict)
            self.gate_correction_bias = None

        if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method):
            if self.fd_config.model_config.is_quantized:
                if getattr(self.fd_config.quant_config, "is_permuted", True):
                    self.quant_method.process_prequanted_weights(self, state_dict)
                else:
                    self.quant_method.process_loaded_weights(self, state_dict)
            else:
                # w_fp16 a_fp16
                self.quant_method.process_loaded_weights(self, state_dict)
        else:
            if self.fd_config.model_config.is_quantized:
                if getattr(self.fd_config.quant_config, "is_permuted", True):
                    self.quant_method.process_prequanted_weights(self, state_dict)
                else:
                    self.quant_method.create_weights(self, state_dict)
            else:
                if self.moe_quant_config:
                    self.quant_method.create_weights(self, state_dict)
                else:
                    # w_fp16 a_fp16
                    self.quant_method.process_loaded_weights(self, state_dict)

    def forward(self, x: paddle.Tensor, gate: nn.Layer):
        """
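For orientation, a condensed sketch of the load-time branching the updated load path follows (dispatch_weight_loading and its flag names are hypothetical stand-ins; the real code branches on fd_config.model_config.is_quantized, quant_config.is_permuted, and is_supported_moe_backend(self.quant_method)):

    def dispatch_weight_loading(quant_method, state_dict, *, backend_supported, is_quantized, is_permuted, has_quant_config):
        # Hypothetical helper mirroring the branching above; quant_method stands
        # in for the backend-specific quantization method object.
        if backend_supported:
            # Parameters were already created in __init__; only fill in values.
            if is_quantized and is_permuted:
                quant_method.process_prequanted_weights(state_dict)
            else:
                quant_method.process_loaded_weights(state_dict)
        else:
            # Fallback for backends that still create their parameters at load time.
            if is_quantized:
                if is_permuted:
                    quant_method.process_prequanted_weights(state_dict)
                else:
                    quant_method.create_weights(state_dict)
            elif has_quant_config:
                quant_method.create_weights(state_dict)
            else:
                quant_method.process_loaded_weights(state_dict)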