Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[V1 Loader]support param create and load for wint2 and xpu backend (#3581)
* support wint2 backend
* [V1 Loader] support param create and load for wint2 and xpu backend
* update weight shape name
* update
* update
* update baseline.txt
* update model name
* update baseline.txt
* fix codestyle
* remove debug code
@@ -27,17 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.experts_manager import RedundantExpertManger
 
-# TODO(lulinjun): remove this import after supporting all backends
-is_supported_moe_backend = None
-if current_platform.is_cuda():
-    from .check_backend_supported import is_supported_moe_backend
-
-
 def get_moe_method():
     """
     return moe method based on device platform
     """
     from fastdeploy.platforms import current_platform
 
     if current_platform.is_cuda():
         from .fused_moe_cutlass_backend import CutlassMoEMethod
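For orientation, get_moe_method() boils down to picking a backend class per platform. A minimal runnable sketch of that dispatch, using hypothetical stand-in classes rather than FastDeploy's real backends:

# Minimal sketch of the platform dispatch in get_moe_method(), assuming a
# plain string platform tag; the real code queries fastdeploy.platforms.
# Both *Stub classes below are illustrative stand-ins, not FastDeploy's.
class CutlassMoEMethodStub:
    """Stand-in for the CUDA (CUTLASS) MoE backend."""

class XPUMoEMethodStub:
    """Stand-in for the XPU MoE backend."""

def get_moe_method_sketch(platform: str):
    if platform == "cuda":
        return CutlassMoEMethodStub()
    if platform == "xpu":
        return XPUMoEMethodStub()
    raise NotImplementedError(f"no MoE method for platform {platform!r}")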
@@ -152,24 +146,12 @@ class FusedMoE(nn.Layer):
         if self.ep_size > 1:
             self.quant_method.init_ep(self)
 
-        if fd_config.load_config.dynamic_load_weight:
-            # It's for RL to build model
-            self.init_moe_weights()
-        if gate_correction_bias is not None:
-            self.gate_correction_bias = gate_correction_bias
-        else:
-            self.gate_correction_bias = None
-        if moe_quant_config:
-            if (
-                moe_quant_config
-                and is_supported_moe_backend is not None
-                and is_supported_moe_backend(self.quant_method)
-            ):
-                self.quant_method.create_weights(self, weight_loader=self.weight_loader)
-            else:
-                # w_fp16 a_fp16
-                self.quant_method.create_weights(self, weight_loader=self.weight_loader)
+        # Merge normal and RL build model
+        if gate_correction_bias is not None:
+            self.gate_correction_bias = gate_correction_bias
+        else:
+            self.gate_correction_bias = None
         self.quant_method.create_weights(self, weight_loader=self.weight_loader)
 
         logger.info(
             f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset + self.num_local_experts}), \
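The effect of this hunk: gate_correction_bias is resolved the same way for normal and RL builds, and create_weights() now runs unconditionally, so each backend (including wint2 and xpu) creates its own parameters up front. A hedged sketch of the resulting flow, with illustrative stub classes that are not FastDeploy's API:

# Stub classes for illustration only: bias is set first, then the quant
# method creates the weights regardless of quantization config.
class _QuantMethodStub:
    def create_weights(self, layer, weight_loader):
        # a real backend would allocate up_gate_proj/down_proj params here
        layer.weights_created = True

class _LayerStub:
    pass

layer = _LayerStub()
gate_correction_bias = None  # illustrative: no bias tensor provided
layer.gate_correction_bias = gate_correction_bias
_QuantMethodStub().create_weights(layer, weight_loader=None)
assert layer.weights_created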
@@ -179,7 +161,6 @@ class FusedMoE(nn.Layer):
         )
 
     def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None):
-        from fastdeploy.platforms import current_platform
 
         if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
             SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
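weight_loader() keys off a SHARD_ID_TO_SHARDED_DIM mapping that the quant method attaches to each parameter when it creates the weights. A purely illustrative sketch of the convention (the concrete dim values depend on each backend's weight layout and are an assumption here):

# Illustrative only: shard ids identify the logical projection ("gate" and
# "up" feed up_gate_proj, "down" feeds down_proj); the value names the
# tensor dim that the shard is partitioned along in the weight layout.
SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "up": 1, "down": 0}

def sharded_dim(shard_id: str) -> int:
    return SHARD_ID_TO_SHARDED_DIM[shard_id]

assert sharded_dim("down") == 0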
@@ -332,86 +313,6 @@ class FusedMoE(nn.Layer):
             for shard_id, weight_name in param_name_maping
         ]
 
-    def init_moe_weights(self):
-        """
-        Initialize the weight shapes and parameters for the MoE layer.
-        Combines weight shape initialization and parameter creation into a single function.
-        """
-        # Initialize weight shapes
-        up_gate_proj_output_dim = self.moe_intermediate_size * 2
-        if self.moe_quant_type in ["block_wise_fp8", "wint8"]:
-            up_gate_proj_weight_shape = [
-                self.num_local_experts,
-                up_gate_proj_output_dim,
-                self.hidden_size,
-            ]
-            down_proj_weight_shape = [
-                self.num_local_experts,
-                self.hidden_size,
-                self.moe_intermediate_size,
-            ]
-        else:
-            up_gate_proj_weight_shape = [
-                self.num_local_experts,
-                self.hidden_size,
-                up_gate_proj_output_dim,
-            ]
-            down_proj_weight_shape = [
-                self.num_local_experts,
-                self.moe_intermediate_size,
-                self.hidden_size,
-            ]
-
-        # Create parameters
-        if self.moe_quant_type == "block_wise_fp8":
-            # (TODO:gaoziyuan)
-            self.weight_dtype = "float8_e4m3fn"
-            self.init_block_wise_fp8_scale()
-        elif self.moe_quant_type == "wint8":
-            self.weight_dtype = "int8"
-            self.init_weight_only_scale()
-
-        # up_gate_proj parameters
-        self.up_gate_proj_weight = self.create_parameter(
-            shape=up_gate_proj_weight_shape,
-            dtype=self.weight_dtype,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        # down_proj parameters
-        self.down_proj_weight = self.create_parameter(
-            shape=down_proj_weight_shape,
-            dtype=self.weight_dtype,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-
-    def init_weight_only_scale(self):
-        """
-        Initialize the weight scale.
-        """
-        self.up_gate_proj_weight_scale = self.create_parameter(
-            shape=[self.num_local_experts, self.moe_intermediate_size * 2],
-            dtype=self._dtype,
-        )
-        self.down_proj_weight_scale = self.create_parameter(
-            shape=[self.num_local_experts, self.hidden_size],
-            dtype=self._dtype,
-        )
-
-    def init_block_wise_fp8_scale(self):
-        """
-        Initialize the weight scale.
-        """
-        self.up_gate_proj_weight_scale = self.create_parameter(
-            shape=[self.num_local_experts, self.moe_intermediate_size * 2 // 128, self.hidden_size // 128],
-            dtype="float32",
-            is_bias=False,
-        )
-        self.down_proj_weight_scale = self.create_parameter(
-            shape=[self.num_local_experts, self.hidden_size // 128, self.moe_intermediate_size // 128],
-            dtype="float32",
-            is_bias=False,
-        )
-
     def load_experts_weight(
         self,
         state_dict: dict,
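The deleted helpers encode the shape conventions that each backend's create_weights() must now reproduce. A worked example, directly following the removed logic, with illustrative sizes:

# Sizes below are illustrative, not from any real model config.
num_local_experts = 8
hidden_size = 1024
moe_intermediate_size = 512
up_gate_proj_output_dim = moe_intermediate_size * 2  # 1024: gate and up fused

# Quantized layouts ("wint8", "block_wise_fp8") store [experts, out, in]:
up_gate_quant = [num_local_experts, up_gate_proj_output_dim, hidden_size]  # [8, 1024, 1024]
down_quant = [num_local_experts, hidden_size, moe_intermediate_size]       # [8, 1024, 512]

# The fp16 path stores [experts, in, out] instead:
up_gate_fp16 = [num_local_experts, hidden_size, up_gate_proj_output_dim]   # [8, 1024, 1024]

# wint8 keeps one scale per output channel:
up_gate_wint8_scale = [num_local_experts, up_gate_proj_output_dim]         # [8, 1024]

# block_wise_fp8 keeps one float32 scale per 128x128 weight block:
up_gate_fp8_scale = [num_local_experts, up_gate_proj_output_dim // 128, hidden_size // 128]  # [8, 8, 8]
down_fp8_scale = [num_local_experts, hidden_size // 128, moe_intermediate_size // 128]       # [8, 8, 4]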
@@ -560,26 +461,13 @@ class FusedMoE(nn.Layer):
         """
         load_state_dict function.
         """
-        if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method):
-            if self.fd_config.model_config.is_quantized:
-                if getattr(self.fd_config.quant_config, "is_permuted", True):
-                    self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
-                else:
-                    self.quant_method.process_loaded_weights(self, state_dict)
+        if self.fd_config.model_config.is_quantized:
+            if getattr(self.fd_config.quant_config, "is_permuted", True):
+                self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
+            else:
+                self.quant_method.process_loaded_weights(self, state_dict)
         else:
-            if self.fd_config.model_config.is_quantized:
-                if getattr(self.fd_config.quant_config, "is_permuted", True):
-                    self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
-                else:
-                    self.quant_method.create_weights(self, state_dict)
-            else:
-                if self.moe_quant_config:
-                    self.quant_method.create_weights(self, state_dict)
-                else:
-                    # w_fp16 a_fp16
-                    self.quant_method.process_loaded_weights(self, state_dict)
+            self.quant_method.process_loaded_weights(self, state_dict)
 
     def forward(self, x: paddle.Tensor, gate: nn.Layer):
         """
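After this hunk, load_state_dict() no longer branches on is_supported_moe_backend. A hedged sketch of the unified dispatch, assuming `layer` exposes fd_config and quant_method the way FusedMoE does:

# Sketch of the unified load_state_dict dispatch after this change.
def load_state_dict_sketch(layer, state_dict, is_rearrange=False):
    cfg = layer.fd_config
    if cfg.model_config.is_quantized:
        # pre-quantized checkpoint: permuted weights take the prequanted path
        if getattr(cfg.quant_config, "is_permuted", True):
            layer.quant_method.process_prequanted_weights(layer, state_dict, is_rearrange)
        else:
            layer.quant_method.process_loaded_weights(layer, state_dict)
    else:
        # bf16/fp16 checkpoint: quantize (if needed) while loading
        layer.quant_method.process_loaded_weights(layer, state_dict)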