[V1 Loader]support param create and load for wint2 and xpu backend (#3581)

* support wint2 backend

* [V1 Loader]support param create and load for wint2 and xpu backend

* update weight shape name

* update

* update

* update baseline.txt

* update model name

* update baseline.txt

* fix codestyle

* remove debug code
Author: Zero Rains
Date: 2025-08-28 09:49:36 +08:00
Committed by: GitHub
Parent: b28a0343a6
Commit: e37e86b3b8
9 changed files with 307 additions and 326 deletions


@@ -27,17 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.experts_manager import RedundantExpertManger
 
-# TODO(lulinjun): remove this import after supporting all backends
-is_supported_moe_backend = None
-if current_platform.is_cuda():
-    from .check_backend_supported import is_supported_moe_backend
-
 
 def get_moe_method():
     """
     return moe method based on device platform
     """
-    from fastdeploy.platforms import current_platform
 
     if current_platform.is_cuda():
         from .fused_moe_cutlass_backend import CutlassMoEMethod
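
The hunk above drops the conditional import of is_supported_moe_backend and the duplicate current_platform import inside get_moe_method, leaving a single platform check per backend. Below is a minimal, self-contained sketch of that dispatch pattern; _DummyPlatform, pick_moe_method_name, and the xpu branch name are invented for illustration, while only current_platform.is_cuda() and CutlassMoEMethod come from the diff.

# Illustrative sketch only, not FastDeploy's implementation: a stand-in
# platform object and a selector that mirrors the shape of get_moe_method().
class _DummyPlatform:
    def __init__(self, name: str):
        self.name = name

    def is_cuda(self) -> bool:
        return self.name == "cuda"

    def is_xpu(self) -> bool:
        return self.name == "xpu"


def pick_moe_method_name(platform: _DummyPlatform) -> str:
    # One branch per backend, resolved lazily so a platform never imports
    # kernels it does not ship (the real code imports the backend class
    # here, e.g. CutlassMoEMethod on CUDA).
    if platform.is_cuda():
        return "CutlassMoEMethod"
    if platform.is_xpu():
        return "XPUMoEMethod"  # hypothetical name for the xpu backend
    raise NotImplementedError(f"no MoE method for platform {platform.name!r}")


print(pick_moe_method_name(_DummyPlatform("xpu")))  # -> XPUMoEMethod
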
@@ -152,24 +146,12 @@ class FusedMoE(nn.Layer):
         if self.ep_size > 1:
             self.quant_method.init_ep(self)
 
-        if fd_config.load_config.dynamic_load_weight:
-            # It's for RL to build model
-            self.init_moe_weights()
+        # Merge normal and RL build model
+        if gate_correction_bias is not None:
+            self.gate_correction_bias = gate_correction_bias
         else:
-            if gate_correction_bias is not None:
-                self.gate_correction_bias = gate_correction_bias
-            else:
-                self.gate_correction_bias = None
-            if moe_quant_config:
-                if (
-                    moe_quant_config
-                    and is_supported_moe_backend is not None
-                    and is_supported_moe_backend(self.quant_method)
-                ):
-                    self.quant_method.create_weights(self, weight_loader=self.weight_loader)
-                else:
-                    # w_fp16 a_fp16
-                    self.quant_method.create_weights(self, weight_loader=self.weight_loader)
+            self.gate_correction_bias = None
+        self.quant_method.create_weights(self, weight_loader=self.weight_loader)
 
         logger.info(
             f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset + self.num_local_experts}), \
@@ -179,7 +161,6 @@ class FusedMoE(nn.Layer):
         )
 
     def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None):
-        from fastdeploy.platforms import current_platform
 
         if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
             SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
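
For context on the weight_loader shown above: it reads a SHARD_ID_TO_SHARDED_DIM attribute attached to the parameter to decide along which axis a given shard ("gate"/"up"/"down") is placed. The following NumPy-only sketch illustrates that idea; FakeParam, toy_weight_loader, and the particular dimension choices are invented, and only the attribute name comes from the diff.

import numpy as np


class FakeParam:
    """Hypothetical parameter carrying a shard-id -> sharded-dim mapping."""

    def __init__(self, shape):
        self.data = np.zeros(shape, dtype=np.float32)
        # e.g. gate/up shards stack along dim 1, down shards along dim 0
        self.SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "up": 1, "down": 0}


def toy_weight_loader(param: FakeParam, loaded_weight: np.ndarray, shard_id: str) -> None:
    # Look up the axis for this shard and copy it into the first slot
    # along that axis.
    sharded_dim = param.SHARD_ID_TO_SHARDED_DIM[shard_id]
    size = loaded_weight.shape[sharded_dim]
    index = [slice(None)] * param.data.ndim
    index[sharded_dim] = slice(0, size)
    param.data[tuple(index)] = loaded_weight


param = FakeParam((8, 16))
toy_weight_loader(param, np.ones((8, 8), dtype=np.float32), shard_id="gate")
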
@@ -332,86 +313,6 @@ class FusedMoE(nn.Layer):
             for shard_id, weight_name in param_name_maping
         ]
 
-    def init_moe_weights(self):
-        """
-        Initialize the weight shapes and parameters for the MoE layer.
-        Combines weight shape initialization and parameter creation into a single function.
-        """
-        # Initialize weight shapes
-        up_gate_proj_output_dim = self.moe_intermediate_size * 2
-        if self.moe_quant_type in ["block_wise_fp8", "wint8"]:
-            up_gate_proj_weight_shape = [
-                self.num_local_experts,
-                up_gate_proj_output_dim,
-                self.hidden_size,
-            ]
-            down_proj_weight_shape = [
-                self.num_local_experts,
-                self.hidden_size,
-                self.moe_intermediate_size,
-            ]
-        else:
-            up_gate_proj_weight_shape = [
-                self.num_local_experts,
-                self.hidden_size,
-                up_gate_proj_output_dim,
-            ]
-            down_proj_weight_shape = [
-                self.num_local_experts,
-                self.moe_intermediate_size,
-                self.hidden_size,
-            ]
-
-        # Create parameters
-        if self.moe_quant_type == "block_wise_fp8":
-            # (TODO:gaoziyuan)
-            self.weight_dtype = "float8_e4m3fn"
-            self.init_block_wise_fp8_scale()
-        elif self.moe_quant_type == "wint8":
-            self.weight_dtype = "int8"
-            self.init_weight_only_scale()
-
-        # up_gate_proj parameters
-        self.up_gate_proj_weight = self.create_parameter(
-            shape=up_gate_proj_weight_shape,
-            dtype=self.weight_dtype,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        # down_proj parameters
-        self.down_proj_weight = self.create_parameter(
-            shape=down_proj_weight_shape,
-            dtype=self.weight_dtype,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-
-    def init_weight_only_scale(self):
-        """
-        Initialize the weight scale.
-        """
-        self.up_gate_proj_weight_scale = self.create_parameter(
-            shape=[self.num_local_experts, self.moe_intermediate_size * 2],
-            dtype=self._dtype,
-        )
-        self.down_proj_weight_scale = self.create_parameter(
-            shape=[self.num_local_experts, self.hidden_size],
-            dtype=self._dtype,
-        )
-
-    def init_block_wise_fp8_scale(self):
-        """
-        Initialize the weight scale.
-        """
-        self.up_gate_proj_weight_scale = self.create_parameter(
-            shape=[self.num_local_experts, self.moe_intermediate_size * 2 // 128, self.hidden_size // 128],
-            dtype="float32",
-            is_bias=False,
-        )
-        self.down_proj_weight_scale = self.create_parameter(
-            shape=[self.num_local_experts, self.hidden_size // 128, self.moe_intermediate_size // 128],
-            dtype="float32",
-            is_bias=False,
-        )
 
     def load_experts_weight(
         self,
         state_dict: dict,
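
The deleted helpers above encode the shape rules that the layer now gets from the backend's create_weights: wint8 keeps one scale per output channel per expert, while block_wise_fp8 keeps one float32 scale per 128x128 weight block. The following standalone restatement of that arithmetic uses an invented function name and example sizes; the formulas themselves are taken from the removed init_weight_only_scale and init_block_wise_fp8_scale.

# Shape arithmetic from the deleted helpers, restated as a runnable sketch.
def moe_scale_shapes(num_local_experts: int, hidden_size: int, moe_intermediate_size: int):
    up_gate_out = moe_intermediate_size * 2  # gate and up projections are fused
    return {
        "wint8": {
            "up_gate_proj_weight_scale": [num_local_experts, up_gate_out],
            "down_proj_weight_scale": [num_local_experts, hidden_size],
        },
        "block_wise_fp8": {
            "up_gate_proj_weight_scale": [num_local_experts, up_gate_out // 128, hidden_size // 128],
            "down_proj_weight_scale": [num_local_experts, hidden_size // 128, moe_intermediate_size // 128],
        },
    }


# Example: 8 local experts, hidden_size=4096, moe_intermediate_size=1536
print(moe_scale_shapes(8, 4096, 1536)["block_wise_fp8"])
# {'up_gate_proj_weight_scale': [8, 24, 32], 'down_proj_weight_scale': [8, 32, 12]}
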
@@ -560,26 +461,13 @@ class FusedMoE(nn.Layer):
"""
load_state_dict function.
"""
if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method):
if self.fd_config.model_config.is_quantized:
if getattr(self.fd_config.quant_config, "is_permuted", True):
self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
else:
self.quant_method.process_loaded_weights(self, state_dict)
if self.fd_config.model_config.is_quantized:
if getattr(self.fd_config.quant_config, "is_permuted", True):
self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
else:
self.quant_method.process_loaded_weights(self, state_dict)
else:
if self.fd_config.model_config.is_quantized:
if getattr(self.fd_config.quant_config, "is_permuted", True):
self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
else:
self.quant_method.create_weights(self, state_dict)
else:
if self.moe_quant_config:
self.quant_method.create_weights(self, state_dict)
else:
# w_fp16 a_fp16
self.quant_method.process_loaded_weights(self, state_dict)
self.quant_method.process_loaded_weights(self, state_dict)
def forward(self, x: paddle.Tensor, gate: nn.Layer):
"""