Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 16:22:57 +08:00)

Commit: qwen3_moe (#3084)
@@ -663,7 +663,7 @@ class LoadChoices(str, Enum):

     DEFAULT = "default"
     # only support qwen3-bf16 now
-    NEW_LOADER = "new_loader"
+    DEFAULT_V1 = "default_v1"


 class LoadConfig:
@@ -22,7 +22,9 @@ import paddle
 from paddle import nn
 from paddleformers.utils.log import logger

-from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
+from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
+    UnquantizedFusedMoEMethod,
+)
 from fastdeploy.model_executor.layers.utils import (
     CpuGuard,
     create_and_set_parameter,
@@ -37,7 +39,7 @@ from fastdeploy.model_executor.ops.gcu import (
 )


-class GCUFusedMoeMethod(MoEMethodBase):
+class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
     """
     Use GCU to compute Fused MoE.
     """
@@ -46,28 +48,12 @@ class GCUFusedMoeMethod(MoEMethodBase):
         super().__init__(quant_config)
         self.group_size = -1

-    def create_weights(self, layer: nn.Layer, state_dict):
-        """
-        Paddle gcu create weight process.
-        """
-        # bf16
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
         up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
         stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
         stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
-        for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
-            # shape [E, K, N] -> [E, N, K]
-            weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
-            weight_name = self.added_weight_attrs[idx]
-            setattr(
-                layer,
-                weight_name,
-                layer.create_parameter(
-                    shape=weight_tensor.shape,
-                    dtype=weight_tensor.dtype,
-                    default_initializer=paddle.nn.initializer.Constant(0),
-                ),
-            )
-            getattr(layer, weight_name).set_value(weight_tensor)
+        layer.up_gate_proj_weight.set_value(paddle.transpose(stacked_up_gate_proj_weights, [0, 2, 1]))
+        layer.down_proj_weight.set_value(paddle.transpose(stacked_down_proj_weights, [0, 2, 1]))

     @paddle.no_grad()
     def compute_ffn(
@@ -202,18 +188,19 @@ class GCUFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle gcu compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         return self.compute_ffn(layer, x, gate_out, enable_quant=False)

     def apply_ep_prefill(
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
@@ -224,7 +211,7 @@ class GCUFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
@@ -235,7 +222,7 @@ class GCUFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
@@ -400,9 +387,10 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle gcu compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         return self.compute_ffn(layer, x, gate_out, enable_quant=True)
@@ -37,7 +37,7 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         self.quant_config = quant_config
         self.group_size = -1

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
         weight_scale_shape = [layer.weight_shape[1]]

@@ -45,6 +45,14 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         if self.quant_config.name() == "wint4":
             layer.weight_shape[0] //= 2
         layer.weight_dtype = "int8"
+
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+
         layer.weight_scale = layer.create_parameter(
             shape=weight_scale_shape,
             dtype=layer._dtype,
@@ -35,7 +35,7 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
     ) -> None:
         super().__init__(quant_config)

-    def create_weights(self, layer: nn.Layer) -> None:
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs) -> None:
         """
         Create weights for linear layer on XPU
         """
@@ -45,6 +45,12 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         if self.quant_config.name() == "weight_only_int4":
             layer.weight_shape[0] //= 2
         layer.weight_dtype = "int8"
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
         layer.weight_scale = layer.create_parameter(
             shape=weight_scale_shape,
             dtype="float32",
@@ -21,6 +21,7 @@ from paddle import nn

 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
+from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.models.utils import (
     default_weight_loader,
     set_weight_attrs,
@@ -30,6 +31,45 @@ from fastdeploy.platforms import current_platform
 from .utils import _set_var_distributed, divide, get_tensor


+class UnquantizedLinearMethod(QuantMethodBase):
+    """Linear method without quantization."""
+
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
+        """
+        extra_weight_attrs is a dictionary that may include parameters like:
+            - split_axis: specifies which axis to split the weight tensor on (for distributed weight partitioning)
+            - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
+            - weight_loader: a callable or method responsible for loading the weight data
+        """
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        set_weight_attrs(
+            layer.weight,
+            {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
+        )
+        if hasattr(layer, "nranks") and layer.nranks > 0:
+            split_axis = extra_weight_attrs.get("split_axis")
+            _set_var_distributed(layer.weight, split_axis=split_axis)
+            set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})
+
+    def process_loaded_weights(self, layer, weights) -> None:
+        # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
+        if layer.weight.dtype != weights.dtype:
+            weights = weights.cast(layer.weight.dtype)
+        layer.weight.set_value(weights)
+
+    def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
+
+        linear_out = paddle.matmul(x, layer.weight)
+        if layer.with_bias:
+            linear_out = paddle.add(linear_out, layer.bias)
+        return linear_out
+
+
 class LinearBase(nn.Layer):
     """
     LinearBase Layer.
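Note: with UnquantizedLinearMethod in place, unquantized (bf16/fp16) linear layers go through the same create_weights / process_loaded_weights / apply hooks as quantized ones, which is what lets the init_weight paths removed below go away. A self-contained sketch of that pattern in plain Paddle (illustrative stand-in classes, not the FastDeploy ones):

import paddle
from paddle import nn


class SketchUnquantizedLinearMethod:
    def create_weights(self, layer, **extra_weight_attrs):
        # Allocate the weight exactly like the method above does.
        layer.weight = layer.create_parameter(
            shape=layer.weight_shape,
            dtype=layer.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

    def process_loaded_weights(self, layer, weights):
        # Cast if the checkpoint dtype differs, then copy into the parameter.
        if layer.weight.dtype != weights.dtype:
            weights = weights.cast(layer.weight.dtype)
        layer.weight.set_value(weights)

    def apply(self, layer, x):
        out = paddle.matmul(x, layer.weight)
        if layer.with_bias:
            out = paddle.add(out, layer.bias)
        return out


class SketchLinear(nn.Layer):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight_shape = [in_dim, out_dim]
        self.weight_dtype = self._helper.get_default_dtype()
        self.with_bias = False
        # The layer always carries a quant_method now; this is the unquantized fallback.
        self.quant_method = SketchUnquantizedLinearMethod()
        self.quant_method.create_weights(self)

    def forward(self, x):
        return self.quant_method.apply(self, x)


layer = SketchLinear(8, 16)
layer.quant_method.process_loaded_weights(layer, paddle.randn([8, 16]))
print(layer(paddle.randn([2, 8])).shape)  # [2, 16]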
@@ -44,6 +84,8 @@ class LinearBase(nn.Layer):
         with_bias: bool = False,
         add_bias: bool = False,
         skip_quant: bool = False,
+        weight_dtype: str = "",
+        weight_key: str = "",
     ):
         """
         Initializes a linear layer and provides additional parameters required for inference and quantization.
@@ -81,6 +123,9 @@ class LinearBase(nn.Layer):
         self.add_bias = add_bias
         self.prefix = prefix
         # key
-        self.weight_key = f"{prefix}.weight"
+        if weight_key:
+            self.weight_key = f"{prefix}.{weight_key}"
+        else:
+            self.weight_key = f"{prefix}.weight"
         self.bias_key = f"{prefix}.bias"
         self.shift_key = f"{prefix}.shift_bias"
@@ -88,39 +133,21 @@ class LinearBase(nn.Layer):
         self.out_scale_key = f"{prefix}.out_scale"

         self._dtype = self._helper.get_default_dtype()
-        self.weight_dtype = self._dtype
+        if weight_dtype:
+            self.weight_dtype = weight_dtype
+        elif self.skip_quant:
+            self.weight_dtype = self._dtype
+        else:
+            self.weight_dtype = self._dtype
         self.weight_shape = [
             self.input_size,
             self.output_size,
         ]
-        if fd_config.quant_config:
+
+        if fd_config.quant_config and not skip_quant:
             self.quant_method = fd_config.quant_config.get_quant_method(self)
-            if fd_config.model_config.is_quantized:
-                self.weight_key = f"{prefix}.quant_weight"
-                self.weight_scale_key = f"{prefix}.weight_scale"
-                self.act_scale_key = f"{prefix}.activation_scale"
-
-    def init_weight(self):
-        """
-        Initialize the weights and biases.
-        """
-        if self.skip_quant:
-            self.weight_dtype = self._dtype
-        self.weight = self.create_parameter(
-            shape=self.weight_shape,
-            dtype=self.weight_dtype,
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-
-        set_weight_attrs(
-            self.weight,
-            {
-                "weight_loader": (
-                    self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
-                )
-            },
-        )
+        else:
+            self.quant_method: Optional[QuantMethodBase] = UnquantizedLinearMethod()

         self.bias = None
         if self.with_bias:
@@ -130,19 +157,15 @@ class LinearBase(nn.Layer):
                 is_bias=True,
             )

-        set_weight_attrs(
-            self.weight,
-            {
-                "weight_loader": (
-                    self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
-                )
-            },
-        )
-
         # smooth quant
         self.linear_shift = None
         self.linear_smooth = None

+        if fd_config.model_config.is_quantized:
+            self.weight_key = f"{prefix}.quant_weight"
+            self.weight_scale_key = f"{prefix}.weight_scale"
+            self.act_scale_key = f"{prefix}.activation_scale"
+
     def load_prequant_weight(self, state_dict: dict):
         """
         Load the prequantized weight from the state dictionary.
@@ -160,11 +183,7 @@ class LinearBase(nn.Layer):
             state_dict (dict): A dictionary containing the weights
         """
         weight_tensor = get_tensor(state_dict.pop(self.weight_key))
-
-        if self.fd_config.quant_config:
-            self.quant_method.process_loaded_weights(self, weight_tensor)
-        else:
-            self.weight.set_value(weight_tensor)
+        self.quant_method.process_loaded_weights(self, weight_tensor)

     def load_state_dict(self, state_dict: dict):
         """
@@ -199,12 +218,7 @@ class LinearBase(nn.Layer):
         Raises:
             NotImplementedError: If the weight dtype is not float8 or act dtype is not equal to weight dtype.
         """
-        if self.fd_config.quant_config:
-            linear_out = self.quant_method.apply(self, x)
-        else:
-            linear_out = paddle.matmul(x, self.weight)
-            if self.with_bias:
-                linear_out = paddle.add(linear_out, self.bias)
+        linear_out = self.quant_method.apply(self, x)

         return linear_out

@@ -223,6 +237,8 @@ class ReplicatedLinear(LinearBase):
         with_bias: bool = False,
         add_bias: bool = False,
         skip_quant: bool = False,
+        weight_dtype: str = "",
+        weight_key: str = "",
     ):
         """
         Initializes a replicated linear layer.
@@ -245,6 +261,8 @@ class ReplicatedLinear(LinearBase):
             with_bias=with_bias,
             add_bias=add_bias,
             skip_quant=skip_quant,
+            weight_dtype=weight_dtype,
+            weight_key=weight_key,
         )

         self.hidden_size = fd_config.model_config.hidden_size
@@ -252,9 +270,14 @@ class ReplicatedLinear(LinearBase):
             self.input_size,
             self.output_size,
         ]
-        if fd_config.quant_config:
-            self.quant_method.create_weights(self)
-        self.init_weight()
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            self,
+            weight_loader=(
+                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            ),
+        )


 class ColumnParallelLinear(LinearBase):
@@ -306,60 +329,22 @@ class ColumnParallelLinear(LinearBase):
             self.input_size,
             self.output_size,
         ]
-        if fd_config.quant_config:
-            self.quant_method.create_weights(self)
-        self.init_weight()
-
-    def init_weight(self):
-        """
-        Initialize the weights and biases.
-        """
-        if self.skip_quant:
-            self.weight_dtype = self._dtype
-        self.weight = self.create_parameter(
-            shape=self.weight_shape,
-            dtype=self.weight_dtype,
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        if self.nranks > 0:
-            # col parallel
-            _set_var_distributed(self.weight, split_axis=1)
-        set_weight_attrs(
-            self.weight,
-            {
-                "output_dim": True,
-                "weight_loader": (
-                    self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
-                ),
-            },
-        )
-
-        self.bias = None
-        if self.with_bias:
-            self.bias = self.create_parameter(
-                shape=[self.output_size],
-                dtype=self._dtype,
-                is_bias=True,
-            )
-            if self.nranks > 0:
-                # col parallel
-                _set_var_distributed(self.bias, split_axis=1)
-                set_weight_attrs(
-                    self.weight,
-                    {
-                        "output_dim": True,
-                        "weight_loader": (
-                            self.weight_loader
-                            if hasattr(self, "weight_loader")
-                            else default_weight_loader(self.fd_config)
-                        ),
-                    },
-                )
-
-        # smooth quant
-        self.linear_shift = None
-        self.linear_smooth = None
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            self,
+            split_axis=1,
+            output_dim=True,
+            weight_loader=(
+                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            ),
+        )
+
+        if self.with_bias:
+            if self.nranks > 0:
+                # col parallel
+                _set_var_distributed(self.bias, split_axis=1)
+                set_weight_attrs(self.bias, {"output_dim": True})


 class MergedColumnParallelLinear(ColumnParallelLinear):
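Note: the keyword arguments handed to create_weights() above follow the UnquantizedLinearMethod docstring earlier in this diff: split_axis names the axis of the [input_size, output_size] weight that is partitioned across ranks, and output_dim records whether checkpoint loading shards on the output dimension. A small self-contained sketch of what split_axis=1 / output_dim=True means for a column-parallel weight (sizes are illustrative):

import paddle

nranks, rank = 2, 0
in_dim, out_dim = 8, 16
full_weight = paddle.arange(in_dim * out_dim, dtype="float32").reshape([in_dim, out_dim])

# Column parallel: every rank keeps a contiguous slice of the output columns (axis 1).
shard = out_dim // nranks
local_weight = full_weight[:, rank * shard : (rank + 1) * shard]
print(local_weight.shape)  # [8, 8], i.e. the full [8, 16] weight split across 2 ranks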
@@ -429,9 +414,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         loaded_weight = get_tensor(loaded_weight)

         if loaded_shard_id == "gate":
-            param[:, : self.output_size // 2] = loaded_weight
+            param = param[:, : self.output_size // 2]
         elif loaded_shard_id == "up":
-            param[:, self.output_size // 2 :] = loaded_weight
+            param = param[:, self.output_size // 2 :]
+
+        assert param.shape == loaded_weight.shape, (
+            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+        )
+        param.copy_(loaded_weight, False)

     def load_state_dict(self, state_dict: dict):
         """
@@ -518,16 +508,21 @@ class QKVParallelLinear(ColumnParallelLinear):
         loaded_weight = get_tensor(loaded_weight)

         if loaded_shard_id == "q":
-            param[:, : self.num_heads_per_rank * self.head_dim] = loaded_weight
+            param = param[:, : self.num_heads_per_rank * self.head_dim]
         elif loaded_shard_id == "k":
-            param[
+            param = param[
                 :,
                 self.num_heads_per_rank
                 * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
                 * self.head_dim,
-            ] = loaded_weight
+            ]
         elif loaded_shard_id == "v":
-            param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :] = loaded_weight
+            param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
+
+        assert param.shape == loaded_weight.shape, (
+            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+        )
+        param.copy_(loaded_weight, False)

     def load_weight(self, state_dict: dict):
         """
@@ -665,62 +660,25 @@ class RowParallelLinear(LinearBase):
         ]
         self._dtype = self._helper.get_default_dtype()

-        if fd_config.quant_config:
-            self.quant_method = fd_config.quant_config.get_quant_method(self)
-            self.quant_method.create_weights(self)
-
-        self.reduce_results = reduce_results
-        self.init_weight()
-
-    def init_weight(self):
-        """
-        Initialize the weights and biases.
-        """
-        if self.skip_quant:
-            self.weight_dtype = self._dtype
-
-        self.weight = self.create_parameter(
-            shape=self.weight_shape,
-            dtype=self.weight_dtype,
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        if self.nranks > 0:
-            # row parallel
-            set_weight_attrs(
-                self.weight,
-                {
-                    "output_dim": False,
-                    "weight_loader": (
-                        self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
-                    ),
-                },
-            )
-            _set_var_distributed(self.weight, split_axis=0)
-
-        self.bias = None
-        if self.with_bias:
-            self.bias = self.create_parameter(
-                shape=[self.hidden_size],
-                dtype=self._dtype,
-                is_bias=True,
-            )
-            if self.nranks > 0:
-                set_weight_attrs(
-                    self.bias,
-                    {
-                        "output_dim": False,
-                        "weight_loader": (
-                            self.weight_loader
-                            if hasattr(self, "weight_loader")
-                            else default_weight_loader(self.fd_config)
-                        ),
-                    },
-                )
-
-        # smooth quant
-        self.linear_shift = None
-        self.linear_smooth = None
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            self,
+            split_axis=0,
+            output_dim=False,
+            weight_loader=(
+                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            ),
+        )
+
+        if self.with_bias:
+            _set_var_distributed(self.bias, split_axis=0)
+            set_weight_attrs(
+                self.bias,
+                {
+                    "output_dim": False,
+                },
+            )
+        self.reduce_results = reduce_results

     def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
         if self.fd_config.quant_config:
@@ -19,6 +19,9 @@ from abc import abstractmethod
 import paddle
 from paddle import nn

+from fastdeploy.model_executor.models.utils import set_weight_attrs
+from fastdeploy.platforms import current_platform
+
 from ..quantization.quant_base import QuantMethodBase


@@ -125,7 +128,7 @@ class MoEMethodBase(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
@@ -137,7 +140,7 @@ class MoEMethodBase(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
@@ -149,7 +152,7 @@ class MoEMethodBase(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
@@ -160,7 +163,7 @@ class MoEMethodBase(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
@@ -168,9 +171,35 @@ class MoEMethodBase(QuantMethodBase):
         if layer.ep_size > 1:
             if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
                 self.ep_prefill_runner.clean_low_latency_buffer()
-                return self.apply_ep_prefill(layer, x, gate_out)
+                return self.apply_ep_prefill(layer, x, gate)
             else:
                 self.ep_decoder_runner.clean_low_latency_buffer()
-                return self.apply_ep_decode(layer, x, gate_out)
+                return self.apply_ep_decode(layer, x, gate)
         else:
-            return self.apply_tp(layer, x, gate_out)
+            return self.apply_tp(layer, x, gate)
+
+
+class UnquantizedFusedMoEMethod(MoEMethodBase):
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
+
+        if current_platform.is_cuda():
+            self.up_gate_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2]
+            self.down_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size, layer.hidden_size]
+        else:
+            self.up_gate_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size * 2, layer.hidden_size]
+            self.down_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size]
+
+        layer.up_gate_proj_weight = layer.create_parameter(
+            shape=self.up_gate_proj_weight_shape,
+            dtype=layer.weight_dtype,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+
+        layer.down_proj_weight = layer.create_parameter(
+            shape=self.down_proj_weight_shape,
+            dtype=layer.weight_dtype,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+
+        set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs)
+        set_weight_attrs(layer.down_proj_weight, extra_weight_attrs)
@@ -24,7 +24,7 @@ from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.platforms import current_platform

 from ..utils import create_and_set_parameter, get_tensor
-from .fused_moe_backend_base import MoEMethodBase
+from .fused_moe_backend_base import UnquantizedFusedMoEMethod

 if current_platform.is_cuda():
     from fastdeploy.model_executor.ops.gpu import (
@@ -64,32 +64,19 @@ def get_moe_scores(
     return scores, topk_values, topk_idx


-class CutlassMoEMethod(MoEMethodBase):
+class CutlassMoEMethod(UnquantizedFusedMoEMethod):
     """
     Use Cutlass Group Gemm to compute Fused MoE.
     This method is the oldest way to compute MoE in Paddle.
     """

-    def create_weights(self, layer: nn.Layer, state_dict):
-        """
-        Paddle cutlass create weight process.
-        """
-        # bf16
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
         up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
         stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
         stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
-        for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
-            weight_name = self.added_weight_attrs[idx]
-            setattr(
-                layer,
-                weight_name,
-                layer.create_parameter(
-                    shape=weight_tensor.shape,
-                    dtype=weight_tensor.dtype,
-                    default_initializer=paddle.nn.initializer.Constant(0),
-                ),
-            )
-            getattr(layer, weight_name).set_value(weight_tensor)
+        layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
+        layer.down_proj_weight.set_value(stacked_down_proj_weights)

     def compute_ffn(
         self,
@@ -134,11 +121,12 @@ class CutlassMoEMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
         """
+        gate_out = gate(x.cast("float32"))
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
         # 2. EP Dispatch
@@ -206,11 +194,12 @@ class CutlassMoEMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
         """
+        gate_out = gate(x.cast("float32"))
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
         expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
@@ -242,11 +231,12 @@ class CutlassMoEMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         if layer.topk_method == "noaux_tc":
             gate_out, _, _ = get_moe_scores(
                 gate_out,
@@ -126,11 +126,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
         """
+        gate_out = gate(x.cast("float32"))
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
         # 2. Dynamic compute blockwise quantization scales
@@ -233,11 +234,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
         """
+        gate_out = gate(x.cast("float32"))
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
         # 2. EP Dispatch
@@ -303,13 +305,13 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Use DeepGemm compute Fused MoE.
         below is TP compute method.
         """
+        gate_out = gate(x.cast("float32"))
         topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
             gate_out,
             layer.gate_correction_bias,
@@ -219,11 +219,12 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Marlin compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         token_num = x.shape[0]
         top_k = layer.top_k
         top_k = layer.top_k
@@ -115,11 +115,12 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         token_num = x.shape[0]
         top_k = layer.top_k
         num_local_experts = layer.num_local_experts
@@ -336,12 +337,12 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         token_num = x.shape[0]
         top_k = layer.top_k
         num_local_experts = layer.num_local_experts
@@ -576,12 +577,12 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         token_num = x.shape[0]
         top_k = layer.top_k
         num_local_experts = layer.num_local_experts
|
@@ -171,12 +171,12 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
|||||||
self,
|
self,
|
||||||
layer: nn.Layer,
|
layer: nn.Layer,
|
||||||
x: paddle.Tensor,
|
x: paddle.Tensor,
|
||||||
gate_out: paddle.Tensor,
|
gate: nn.Layer,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
"""
|
"""
|
||||||
Use Wint2 Triton Fusedmoe compute Fused MoE.
|
Use Wint2 Triton Fusedmoe compute Fused MoE.
|
||||||
"""
|
"""
|
||||||
|
gate_out = gate(x.cast("float32"))
|
||||||
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
|
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
|
||||||
|
|
||||||
(
|
(
|
||||||
@@ -242,12 +242,12 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Use Wint2 Triton Fusedmoe compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         from fastdeploy.model_executor.ops.triton_ops import moe_wint2_ffn_kernel

         topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
@@ -19,47 +19,36 @@ from typing import Dict
|
|||||||
import paddle
|
import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
|
||||||
|
UnquantizedFusedMoEMethod,
|
||||||
|
)
|
||||||
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
|
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
|
||||||
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
|
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
|
||||||
from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu
|
from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu
|
||||||
|
|
||||||
from .fused_moe_backend_base import MoEMethodBase
|
|
||||||
|
|
||||||
|
class XPUMoEMethod(UnquantizedFusedMoEMethod):
|
||||||
class XPUMoEMethod(MoEMethodBase):
|
|
||||||
"""
|
"""
|
||||||
XPU MOE
|
XPU MOE
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create_weights(self, layer: nn.Layer, state_dict):
|
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||||
"""
|
|
||||||
Paddle cutlass create weight process.
|
|
||||||
"""
|
|
||||||
# bf16
|
|
||||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||||
for weights in [up_gate_proj_weights, down_proj_weights]:
|
for weights in [up_gate_proj_weights, down_proj_weights]:
|
||||||
for idx, weight in enumerate(weights):
|
for idx, weight in enumerate(weights):
|
||||||
weights[idx] = weight.transpose([1, 0])
|
weights[idx] = weight.transpose([1, 0])
|
||||||
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
|
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
|
||||||
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
|
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
|
||||||
for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
|
|
||||||
weight_name = self.added_weight_attrs[idx]
|
layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
|
||||||
setattr(
|
layer.down_proj_weight.set_value(stacked_down_proj_weights)
|
||||||
layer,
|
|
||||||
weight_name,
|
|
||||||
layer.create_parameter(
|
|
||||||
shape=weight_tensor.shape,
|
|
||||||
dtype=weight_tensor.dtype,
|
|
||||||
default_initializer=paddle.nn.initializer.Constant(0),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
getattr(layer, weight_name).set_value(weight_tensor)
|
|
||||||
|
|
||||||
def apply_tp(
|
def apply_tp(
|
||||||
self,
|
self,
|
||||||
layer: nn.Layer,
|
layer: nn.Layer,
|
||||||
x: paddle.Tensor,
|
x: paddle.Tensor,
|
||||||
gate_out: paddle.Tensor,
|
gate: nn.Layer,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
"""
|
"""
|
||||||
Paddle Cutlass compute Fused MoE.
|
Paddle Cutlass compute Fused MoE.
|
||||||
@@ -68,7 +57,7 @@ class XPUMoEMethod(MoEMethodBase):

         fused_moe_out = xpu_moe_layer(
             x,
-            layer.gate_weight.transpose([1, 0]),
+            gate.weight.transpose([1, 0]),
             layer.gate_correction_bias,
             layer.up_gate_proj_weight,
             layer.down_proj_weight,
@@ -94,7 +83,7 @@ class XPUMoEMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
@@ -105,7 +94,7 @@ class XPUMoEMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
@@ -187,7 +176,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         XPU compute Fused MoE.
@@ -196,7 +185,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):

         fused_moe_out = xpu_moe_layer(
             x,
-            layer.gate_weight.transpose([1, 0]),
+            gate.weight.transpose([1, 0]),
             layer.gate_correction_bias,
             layer.up_gate_proj_weight,
             layer.down_proj_weight,
|
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
from paddleformers.utils.log import logger
|
from paddleformers.utils.log import logger
|
||||||
@@ -77,7 +79,7 @@ class FusedMoE(nn.Layer):
         self.fd_config = fd_config
         self.layer_idx = layer_idx
         self.reduce_results = reduce_results
-
+        self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
         self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.ep_size = fd_config.parallel_config.expert_parallel_size
         self.ep_rank = fd_config.parallel_config.expert_parallel_rank
@@ -109,14 +111,19 @@ class FusedMoE(nn.Layer):
         self.n_group = n_group
         self.routed_scaling_factor = routed_scaling_factor

+        self._dtype = self._helper.get_default_dtype()
+        self.weight_dtype = self._dtype
+
         moe_quant_config = fd_config.quant_config
+        self.moe_quant_config = moe_quant_config
         self.moe_quant_type = None
         if moe_quant_config:
             self.quant_method = moe_quant_config.get_quant_method(self)
             self.moe_quant_type = moe_quant_config.name()
         else:
-            # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
+            # w_fp16 a_fp16
             self.quant_method = get_moe_method()
+        self.quant_method.create_weights(self, weight_loader=self.weight_loader)

         self.redundant_table_manger = None
         if self.ep_size > 1:
@@ -140,21 +147,121 @@ class FusedMoE(nn.Layer):
                 tp_size={self.tp_size}."
             )

+    def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None):
+        from fastdeploy.platforms import current_platform
+
+        if shard_id is None:
+            # 1.gate up fused in disk
+            return
+        # 2.gate up splited in disk
+        assert shard_id in ["gate", "down", "up"]
+        expert_param = param[expert_id]
+        if current_platform.is_cuda():
+            SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
+        else:
+            SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
+        self._load_expert_weight(
+            expert_param=expert_param,
+            shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
+            loaded_weight=loaded_weight,
+            shard_id=shard_id,
+        )
+
+    def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
+        tensor_size = expert_param.shape[shard_dim] // 2
+        if shard_id == "gate":
+            expert_param = expert_param[..., :tensor_size] if shard_dim else expert_param[:tensor_size, ...]
+        elif shard_id == "up":
+            expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...]
+
+        if self.tp_size > 1:
+            size = loaded_weight.get_shape()[-1]
+            block_size = size // self.tp_size
+            shard_offset = self.tp_rank * block_size
+            shard_size = (self.tp_rank + 1) * block_size
+            loaded_weight = loaded_weight[..., shard_offset:shard_size]
+
+        loaded_weight = get_tensor(loaded_weight)
+        # To ensure compatibility across backends, apply an extra transpose for GCU and XPU
+        if expert_param.shape != loaded_weight.shape:
+            loaded_weight = loaded_weight.transpose([1, 0])
+        assert expert_param.shape == loaded_weight.shape, (
+            f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
+        )
+        expert_param.copy_(loaded_weight, False)
+
+    def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
+        if self.tp_size > 1:
+            size = loaded_weight.get_shape()[shard_dim]
+            block_size = size // self.tp_size
+            shard_offset = self.tp_rank * block_size
+            shard_size = (self.tp_rank + 1) * block_size
+            loaded_weight = loaded_weight[shard_offset:shard_size, ...]
+        loaded_weight = get_tensor(loaded_weight)
+        # To ensure compatibility across backends, apply an extra transpose for GCU and XPU
+        if expert_param.shape != loaded_weight.shape:
+            loaded_weight = loaded_weight.transpose([1, 0])
+        assert expert_param.shape == loaded_weight.shape, (
+            f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
+        )
+        expert_param.copy_(loaded_weight, False)
+
+    def _load_expert_weight(
+        self,
+        expert_param,
+        shard_dim,
+        loaded_weight,
+        shard_id,
+    ):
+        if shard_id == "down":
+            self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
+        elif shard_id in ["gate", "up"]:
+            self._load_gate_up_weight(expert_param, shard_dim, loaded_weight, shard_id)
+
+    @classmethod
+    def make_expert_params_mapping(
+        cls,
+        ckpt_gate_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_up_proj_name: str,
+        param_gate_up_proj_name: str,
+        param_down_proj_name: str,
+        num_experts: int,
+        ckpt_expert_key_name: str = "experts",
+        ckpt_gate_up_proj_name: Optional[str] = None,
+    ) -> list[tuple[str, str, int, str]]:
+        param_name_maping = [
+            ("gate", ckpt_gate_proj_name),
+            ("down", ckpt_down_proj_name),
+            ("up", ckpt_up_proj_name),
+        ]
+        if ckpt_gate_up_proj_name:
+            param_name_maping.append((None, ckpt_gate_up_proj_name))
+
+        return [
+            # (param_name, weight_name, expert_id, shard_id)
+            (
+                (
+                    param_gate_up_proj_name
+                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                    else param_down_proj_name
+                ),
+                f"{ckpt_expert_key_name}.{expert_id}.{weight_name}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in param_name_maping
+        ]
+
     def init_moe_weights(self):
         """
         Initialize the weight shapes and parameters for the MoE layer.
         Combines weight shape initialization and parameter creation into a single function.
         """
         # Initialize weight shapes
-        self._dtype = self._helper.get_default_dtype()
-        self.weight_dtype = self._dtype
-        gate_weight_shape = [self.hidden_size, self.num_experts]
         gate_correction_bias_shape = [1, self.num_experts]

-        self.gate_weight = self.create_parameter(
-            shape=gate_weight_shape,
-            dtype="float32",
-        )
         if self.fd_config.model_config.moe_use_aux_free:
             self.gate_correction_bias = self.create_parameter(
                 shape=gate_correction_bias_shape,
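Note: make_expert_params_mapping() yields one (param_name, checkpoint_weight_name, expert_id, shard_id) tuple per expert per projection; model code uses these tuples to route checkpoint tensors into weight_loader() above. A self-contained sketch of the same comprehension with illustrative checkpoint and parameter names (runs without FastDeploy):

ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name = "gate_proj", "down_proj", "up_proj"
param_gate_up_proj_name, param_down_proj_name = "experts.up_gate_proj_", "experts.down_proj_"
num_experts = 2

param_name_maping = [
    ("gate", ckpt_gate_proj_name),
    ("down", ckpt_down_proj_name),
    ("up", ckpt_up_proj_name),
]
mapping = [
    (
        param_gate_up_proj_name if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] else param_down_proj_name,
        f"experts.{expert_id}.{weight_name}.",
        expert_id,
        shard_id,
    )
    for expert_id in range(num_experts)
    for shard_id, weight_name in param_name_maping
]
for entry in mapping:
    print(entry)
# e.g. ('experts.up_gate_proj_', 'experts.0.gate_proj.', 0, 'gate')
#      ('experts.down_proj_', 'experts.0.down_proj.', 0, 'down')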
@@ -374,26 +481,19 @@ class FusedMoE(nn.Layer):
             )
             self.gate_correction_bias.set_value(gate_correction_bias_tensor)

-        gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
-        assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
-
-        gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
-
-        self.gate_weight = self.create_parameter(
-            shape=gate_weight_tensor.shape,
-            dtype="float32",
-        )
-        self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
-
         if self.fd_config.model_config.is_quantized:
             if getattr(self.fd_config.quant_config, "is_permuted", True):
                 self.quant_method.process_prequanted_weights(self, state_dict)
             else:
                 self.quant_method.create_weights(self, state_dict)
         else:
-            self.quant_method.create_weights(self, state_dict)
+            if self.moe_quant_config:
+                self.quant_method.create_weights(self, state_dict)
+            else:
+                # w_fp16 a_fp16
+                self.quant_method.process_loaded_weights(self, state_dict)

-    def forward(self, x: paddle.Tensor):
+    def forward(self, x: paddle.Tensor, gate: nn.Layer):
         """
         Defines the forward computation of the moe layer.

@@ -404,6 +504,5 @@ class FusedMoE(nn.Layer):
             Tensor: Output tensor.s

         """
-        gate_out = paddle.matmul(x.cast("float32"), self.gate_weight)
-        out = self.quant_method.apply(self, x, gate_out)
+        out = self.quant_method.apply(self, x, gate)
         return out
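Note: the new forward(x, gate) signature means the caller owns the gate layer and passes it in; FusedMoE no longer keeps a gate_weight of its own, and each backend computes gate_out = gate(x.cast("float32")) itself (see the apply_* changes above). A self-contained sketch of the new call contract with stand-in classes (illustrative only, not the real Qwen3 MoE block):

import paddle
from paddle import nn


class SketchFusedMoE(nn.Layer):
    """Stand-in with the new call contract: forward(x, gate)."""

    def forward(self, x, gate):
        gate_out = gate(x.cast("float32"))  # routing logits come from the passed-in gate layer
        # ... expert dispatch / combine would happen here, driven by gate_out ...
        return x  # identity stand-in


class SketchMoeBlock(nn.Layer):
    def __init__(self, hidden_size, num_experts):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts, bias_attr=False)  # stand-in gate layer
        self.experts = SketchFusedMoE()

    def forward(self, hidden_states):
        # Old contract: self.experts(hidden_states)   (FusedMoE owned gate_weight itself)
        # New contract: the caller hands its gate layer to the MoE layer.
        return self.experts(hidden_states, self.gate)


block = SketchMoeBlock(hidden_size=16, num_experts=4)
print(block(paddle.randn([2, 16])).shape)  # [2, 16]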
|
@@ -81,8 +81,16 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
def create_weights(self, layer):
|
def create_weights(self, layer, **extra_weight_attrs):
|
||||||
layer.weight_shape.reverse()
|
layer.weight_shape.reverse()
|
||||||
|
|
||||||
|
layer.weight = layer.create_parameter(
|
||||||
|
shape=layer.weight_shape,
|
||||||
|
dtype=layer.weight_dtype,
|
||||||
|
is_bias=False,
|
||||||
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
|
)
|
||||||
|
|
||||||
layer.weight_scale = layer.create_parameter(
|
layer.weight_scale = layer.create_parameter(
|
||||||
shape=[
|
shape=[
|
||||||
(layer.output_size + self.quant_config.weight_block_size[0] - 1)
|
(layer.output_size + self.quant_config.weight_block_size[0] - 1)
|
||||||
|
@@ -16,6 +16,8 @@

 from typing import Optional

+import paddle
+
 from fastdeploy.model_executor.layers.moe import FusedMoE

 from ..utils import get_tensor
@@ -79,11 +81,14 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
         self.quant_round_type = 1
         self.weight_dtype = "float8_e4m3fn"

-    def create_weights(self, layer):
-        """
-        Nothing to do!
-        """
-        pass
+    def create_weights(self, layer, **extra_weight_attrs):
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )

     def process_prequanted_weights(self, layer, state_dict) -> None:
         """
@@ -63,11 +63,17 @@ class W4AFP8LinearMethod(QuantMethodBase):
         super().__init__()
         self.quant_config = quant_config

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         layer.weight_shape.reverse()
         layer.weight_shape[0] //= 2
         layer.weight_dtype = "int8"
-        pass
+
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )

     def process_loaded_weights(self, layer, weights) -> None:
         (
@@ -74,7 +74,7 @@ class W8A8LinearMethod(QuantMethodBase):
         self.quant_config = quant_config
         self.smooth_quant_method = SmoothQuantLinearMethod(quant_config)

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         layer.weight_shape.reverse()
         layer.weight_dtype = "int8"
         if self.quant_config.use_smooth_quant:
@@ -85,7 +85,12 @@ class W8A8LinearMethod(QuantMethodBase):
         if weight_scale is None or in_scale is None:
             self.skip_quant = True
             return
+        layer.wieght = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
         max_range = 127.0
         linear_out_scale = paddle.to_tensor(weight_scale / (max_range * max_range * in_scale)).astype("float32")
         layer.linear_out_scale = layer.create_parameter(
@@ -136,7 +141,7 @@ class SmoothQuantLinearMethod(QuantMethodBase):
         super().__init__()
         self.quant_config = quant_config

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         linear_shift_shape = [layer.output_size]
         linear_smooth_shape = [layer.output_size]
         layer.linear_shift = self.create_parameter(
@@ -168,7 +168,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         super().__init__()
         self.quant_config = quant_config

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):

         # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
         weight_scale_shape = [layer.weight_shape[1]]
@@ -177,6 +177,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         if self.quant_config.name() == "wint4":
             layer.weight_shape[0] //= 2
         layer.weight_dtype = "int8"
+
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+
         layer.weight_scale = layer.create_parameter(
             shape=weight_scale_shape,
             dtype=layer._dtype,
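The create_weights changes across these quant methods all follow one pattern: allocate a zero-initialized placeholder parameter up front (Constant(0)) and let process_loaded_weights / process_prequanted_weights fill it in later. A small standalone sketch of the wint4 shape bookkeeping seen in the weight-only path above; the concrete sizes are made up:

    quant_name = "wint4"                     # illustrative; comes from quant_config.name()
    weight_shape = [4096, 11008]             # weight_shape[1] is the output dim, per the comment above
    weight_scale_shape = [weight_shape[1]]   # per-channel scale over the output dim
    if quant_name == "wint4":
        weight_shape[0] //= 2                # two 4-bit values packed into each stored int8
    print(weight_shape, weight_scale_shape)  # [2048, 11008] [11008]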
@@ -69,12 +69,18 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
         super().__init__()
         self.quant_config = quant_config

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         """ """
         layer.weight_shape.reverse()
         layer.weight_dtype = "float8_e4m3fn"
         # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
         self.skip_quant = False
+        layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
         layer.weight_scale = layer.create_parameter(
             shape=[1],
             dtype="float32",
@@ -17,14 +17,16 @@
 from fastdeploy.config import LoadChoices, LoadConfig
 from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
 from fastdeploy.model_executor.model_loader.default_loader import DefaultModelLoader
-from fastdeploy.model_executor.model_loader.new_loader import NewModelLoader
+from fastdeploy.model_executor.model_loader.default_loader_v1 import (
+    DefaultModelLoaderV1,
+)


 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """get_model_loader"""

-    if load_config.load_choices == LoadChoices.NEW_LOADER:
-        return NewModelLoader(load_config)
+    if load_config.load_choices == LoadChoices.DEFAULT_V1:
+        return DefaultModelLoaderV1(load_config)

     return DefaultModelLoader(load_config)
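Loader selection stays a one-liner for callers. A usage sketch, assuming load_config and fd_config are built elsewhere (for example from engine arguments) and that get_model_loader is exported from the model_loader package as in this file:

    from fastdeploy.model_executor.model_loader import get_model_loader

    loader = get_model_loader(load_config)   # DefaultModelLoaderV1 if load_choices == LoadChoices.DEFAULT_V1, else DefaultModelLoader
    model = loader.load_model(fd_config)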
@@ -14,6 +14,8 @@
 # limitations under the License.
 """

+import contextlib
+
 import paddle
 from paddle import nn
 from paddleformers.utils.log import logger
@@ -62,15 +64,16 @@ class DefaultModelLoader(BaseModelLoader):
         self.clean_memory_fragments(state_dict)

     def load_model(self, fd_config: FDConfig) -> nn.Layer:
-        context = paddle.LazyGuard()
         architectures = fd_config.model_config.architectures[0]
         logger.info(f"Starting to load model {architectures}")

         if fd_config.load_config.dynamic_load_weight:
             # register rl model
             import fastdeploy.rl  # noqa

             architectures = architectures + "RL"
+            context = paddle.LazyGuard()
+        else:
+            context = contextlib.nullcontext()

         with context:
             model_cls = ModelRegistry.get_class(architectures)
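The same context selection is repeated in the v1 loader below. A condensed sketch of the logic, as a standalone helper rather than the method itself:

    import contextlib

    import paddle

    def build_model_context(dynamic_load_weight: bool):
        # LazyGuard defers parameter allocation for the RL dynamic-load path; otherwise build eagerly.
        return paddle.LazyGuard() if dynamic_load_weight else contextlib.nullcontext()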
@@ -14,6 +14,8 @@
 # limitations under the License.
 """

+import contextlib
+
 import paddle
 from paddle import nn
 from paddleformers.utils.log import logger
@@ -29,7 +31,7 @@ from fastdeploy.model_executor.models.model_base import ModelRegistry
 from fastdeploy.platforms import current_platform


-class NewModelLoader(BaseModelLoader):
+class DefaultModelLoaderV1(BaseModelLoader):
     """ModelLoader that can load registered models"""

     def __init__(self, load_config: LoadConfig):
@@ -54,13 +56,17 @@ class NewModelLoader(BaseModelLoader):
     def load_model(self, fd_config: FDConfig) -> nn.Layer:
         architectures = fd_config.model_config.architectures[0]
         logger.info(f"Starting to load model {architectures}")

         if fd_config.load_config.dynamic_load_weight:
             # register rl model
             import fastdeploy.rl  # noqa

             architectures = architectures + "RL"
+            context = paddle.LazyGuard()
+
+        else:
+            context = contextlib.nullcontext()

+        with context:
             model_cls = ModelRegistry.get_class(architectures)
             model = model_cls(fd_config)
@@ -117,13 +117,12 @@ class DeepSeekV3MoE(nn.Layer):
         self.tp_size = fd_config.parallel_config.tensor_parallel_size

         weight_key_map = {
-            "gate_weight_key": f"{prefix}.gate.weight",
             "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
             "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
             "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
         }

-        self.fused_moe = FusedMoE(
+        self.experts = FusedMoE(
             fd_config=fd_config,
             reduce_results=False,
             moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
@@ -137,6 +136,16 @@ class DeepSeekV3MoE(nn.Layer):
             weight_key_map=weight_key_map,
         )

+        self.gate = ReplicatedLinear(
+            fd_config=fd_config,
+            prefix=f"{prefix}.gate",
+            input_size=fd_config.model_config.hidden_size,
+            output_size=fd_config.model_config.n_routed_experts,
+            with_bias=False,
+            skip_quant=True,
+            weight_dtype="float32",
+        )
+
         self.num_shared_experts = fd_config.model_config.n_shared_experts
         shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size

@@ -149,13 +158,14 @@ class DeepSeekV3MoE(nn.Layer):

     def load_state_dict(self, state_dict):
         """ """
-        self.fused_moe.load_state_dict(state_dict)
+        self.gate.load_state_dict(state_dict)
+        self.experts.load_state_dict(state_dict)
         self.shared_experts.load_state_dict(state_dict)

     def forward(self, hidden_states: paddle.Tensor):
         """ """
         shared_experts_out = self.shared_experts(hidden_states)
-        moe_out = self.fused_moe(hidden_states)
+        moe_out = self.experts(hidden_states, self.gate)
         moe_out = moe_out + shared_experts_out
         # We do to TP all reduce after the sum of experts.
         if self.tp_size > 1:
@@ -37,6 +37,7 @@ from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
 from fastdeploy.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
@@ -147,7 +148,7 @@ class Ernie4_5_MoE(nn.Layer):
             "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
         }

-        self.fused_moe = FusedMoE(
+        self.experts = FusedMoE(
             fd_config=fd_config,
             moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
             num_experts=fd_config.model_config.moe_num_experts,
@@ -156,6 +157,16 @@ class Ernie4_5_MoE(nn.Layer):
             weight_key_map=weight_key_map,
         )

+        self.gate = ReplicatedLinear(
+            fd_config=fd_config,
+            prefix=f"{prefix}.gate",
+            input_size=fd_config.model_config.hidden_size,
+            output_size=fd_config.model_config.moe_num_experts,
+            with_bias=False,
+            skip_quant=True,
+            weight_dtype="float32",
+        )
+
         self.num_shared_experts = fd_config.model_config.moe_num_shared_experts
         if self.num_shared_experts > 0:
             shared_experts_hidden_dim = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
@@ -166,12 +177,13 @@ class Ernie4_5_MoE(nn.Layer):
             )

     def load_state_dict(self, state_dict):
-        self.fused_moe.load_state_dict(state_dict)
+        self.gate.load_state_dict(state_dict)
+        self.experts.load_state_dict(state_dict)
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)

     def forward(self, hidden_states: paddle.Tensor):
-        out = self.fused_moe(hidden_states)
+        out = self.experts(hidden_states, self.gate)
         if self.num_shared_experts > 0:
             s_x = self.shared_experts(hidden_states)
             out = out + s_x
@@ -435,7 +447,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
             self.fd_config.model_config.moe_layer_start_index,
             self.fd_config.model_config.num_hidden_layers,
         ):
-            self.ernie.layers[i].mlp.fused_moe(fake_hidden_states)
+            self.ernie.layers[i].mlp.expert(fake_hidden_states)

     def forward(
         self,
@@ -33,6 +33,7 @@ from fastdeploy.model_executor.graph_optimization.decorator import (
     support_graph_optimization,
 )
 from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
+from fastdeploy.model_executor.layers.linear import ReplicatedLinear
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.moe.moe import FusedMoE
 from fastdeploy.model_executor.layers.normalization import RMSNorm
@@ -73,6 +74,93 @@ class VLMoEMeta:
     fake_hidden_states: Optional[paddle.Tensor] = None


+class Ernie4_5_VLMoeBlock(nn.Layer):
+    def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str, moe_tag: str, expert_id_offset: int) -> None:
+        super().__init__()
+        moe_quant_type = ""
+        if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
+            moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
+
+        if moe_quant_type == "tensor_wise_fp8" or (
+            moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized
+        ):
+            weight_key_map = {
+                "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
+                "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
+                "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
+                "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
+                "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
+                "up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
+                "down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
+            }
+        else:
+            # wint4/wint8/bfloat16
+            weight_key_map = {
+                "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
+                "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
+                "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
+            }
+        moe_intermediate_size = (
+            fd_config.model_config.moe_intermediate_size[0]
+            if moe_tag == "Text"
+            else fd_config.model_config.moe_intermediate_size[1]
+        )
+        num_experts = (
+            fd_config.model_config.moe_num_experts[0]
+            if moe_tag == "Text"
+            else fd_config.model_config.moe_num_experts[1]
+        )
+        self.experts = FusedMoE(
+            fd_config=fd_config,
+            reduce_results=False,
+            moe_intermediate_size=moe_intermediate_size,
+            num_experts=num_experts,
+            expert_id_offset=expert_id_offset,
+            top_k=fd_config.model_config.moe_k,
+            layer_idx=layer_id,
+            moe_tag=moe_tag,
+            weight_key_map=weight_key_map,
+        )
+
+        self.gate = ReplicatedLinear(
+            fd_config=fd_config,
+            prefix=f"{prefix}.gate",
+            input_size=fd_config.model_config.hidden_size,
+            output_size=num_experts,
+            with_bias=False,
+            skip_quant=True,
+            weight_dtype="float32",
+            weight_key="weight" if moe_tag == "Text" else "weight_1",
+        )
+
+        if moe_tag == "Text":
+            self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_text
+        elif moe_tag == "Image":
+            self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_image
+
+    def forward(self, hidden_states: paddle.Tensor):
+        out = self.experts(hidden_states, self.gate)
+        return out
+
+    def extract_gate_correction_bias_text(self, gate_correction_bias_key, state_dict):
+        """
+        extract_gate_correction_bias function.
+        """
+        gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
+        return gate_correction_bias_tensor[0].unsqueeze(0)
+
+    def extract_gate_correction_bias_image(self, gate_correction_bias_key, state_dict):
+        """
+        extract_gate_correction_bias function.
+        """
+        gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
+        return gate_correction_bias_tensor[1].unsqueeze(0)
+
+    def load_state_dict(self, state_dict):
+        self.experts.load_state_dict(state_dict)
+        self.gate.load_state_dict(state_dict)
+
+
 class Ernie4_5_VLMoE(nn.Layer):
     def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
         super().__init__()
@@ -99,43 +187,10 @@ class Ernie4_5_VLMoE(nn.Layer):

         assert text_moe_layer_start_index <= text_moe_layer_end_index

-        moe_quant_type = ""
-        if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
-            moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
-
         if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
-            if moe_quant_type == "tensor_wise_fp8" or (
-                moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized
-            ):
-                weight_key_map = {
-                    "gate_weight_key": f"{prefix}.gate.weight",
-                    "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
-                    "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
-                    "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
-                    "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
-                    "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
-                    "up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
-                    "down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
-                }
-            else:
-                weight_key_map = {
-                    "gate_weight_key": f"{prefix}.gate.weight",
-                    "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
-                    "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
-                    "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
-                }
-            self.text_fused_moe = FusedMoE(
-                fd_config=fd_config,
-                reduce_results=False,
-                moe_intermediate_size=fd_config.model_config.moe_intermediate_size[0],
-                num_experts=fd_config.model_config.moe_num_experts[0],
-                expert_id_offset=0,
-                top_k=fd_config.model_config.moe_k,
-                layer_idx=layer_id,
-                moe_tag="Text",
-                weight_key_map=weight_key_map,
+            self.text_fused_moe = Ernie4_5_VLMoeBlock(
+                fd_config=fd_config, layer_id=layer_id, prefix=f"{prefix}", moe_tag="Text", expert_id_offset=0
             )
-            self.text_fused_moe.extract_gate_correction_bias = self.extract_gate_correction_bias_text
         else:
             self.text_fused_moe = Ernie4_5_VLMLP(
                 fd_config=fd_config,
@@ -146,38 +201,13 @@ class Ernie4_5_VLMoE(nn.Layer):

         assert image_moe_layer_start_index <= image_moe_layer_end_index
         if layer_id >= image_moe_layer_start_index and layer_id <= image_moe_layer_end_index:
-            if moe_quant_type == "tensor_wise_fp8" or (
-                moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized
-            ):
-                weight_key_map = {
-                    "gate_weight_key": f"{prefix}.gate.weight_1",
-                    "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
-                    "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
-                    "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
-                    "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
-                    "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
-                    "up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
-                    "down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
-                }
-            else:
-                weight_key_map = {
-                    "gate_weight_key": f"{prefix}.gate.weight_1",
-                    "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
-                    "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
-                    "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
-                }
-            self.image_fused_moe = FusedMoE(
+            self.image_fused_moe = Ernie4_5_VLMoeBlock(
                 fd_config=fd_config,
-                reduce_results=False,
-                moe_intermediate_size=fd_config.model_config.moe_intermediate_size[1],
-                num_experts=fd_config.model_config.moe_num_experts[1],
-                expert_id_offset=fd_config.model_config.moe_num_experts[0],
-                top_k=fd_config.model_config.moe_k,
-                layer_idx=layer_id,
+                layer_id=layer_id,
+                prefix=f"{prefix}",
                 moe_tag="Image",
-                weight_key_map=weight_key_map,
+                expert_id_offset=fd_config.model_config.moe_num_experts[0],
             )
-            self.image_fused_moe.extract_gate_correction_bias = self.extract_gate_correction_bias_image
         else:
             self.image_fused_moe = Ernie4_5_VLMLP(
                 fd_config=fd_config,
@@ -195,25 +225,11 @@ class Ernie4_5_VLMoE(nn.Layer):
                 reduce_results=False,
             )

-    def extract_gate_correction_bias_text(self, gate_correction_bias_key, state_dict):
-        """
-        extract_gate_correction_bias function.
-        """
-        gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
-        return gate_correction_bias_tensor[0].unsqueeze(0)
-
-    def extract_gate_correction_bias_image(self, gate_correction_bias_key, state_dict):
-        """
-        extract_gate_correction_bias function.
-        """
-        gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
-        return gate_correction_bias_tensor[1].unsqueeze(0)
-
     def load_state_dict(self, state_dict):
         self.text_fused_moe.load_state_dict(state_dict)
         self.image_fused_moe.load_state_dict(state_dict)
-        if self.text_fused_moe.moe_use_gate_correction_bias:
-            state_dict.pop(self.text_fused_moe.gate_correction_bias_key)
+        if self.text_fused_moe.experts.moe_use_gate_correction_bias:
+            state_dict.pop(self.text_fused_moe.experts.gate_correction_bias_key)
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)
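The text and image expert blocks share one global expert table, and expert_id_offset is what keeps their id ranges disjoint: the Text block is built with offset 0 and the Image block starts after the text experts. A toy illustration with made-up expert counts:

    moe_num_experts = [64, 64]          # [text experts, image experts], as indexed in the code above
    text_ids = range(0, moe_num_experts[0])                                         # Text block: expert_id_offset = 0
    image_ids = range(moe_num_experts[0], moe_num_experts[0] + moe_num_experts[1])  # Image block starts after the text experts
    print(list(text_ids)[:3], list(image_ids)[:3])   # [0, 1, 2] [64, 65, 66]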
@@ -32,6 +32,7 @@ from fastdeploy.model_executor.layers.activation import SiluAndMul
 from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
 from fastdeploy.model_executor.layers.linear import (
     MergedColumnParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
@@ -41,6 +42,47 @@ from fastdeploy.model_executor.models.model_base import ModelForCasualLM
 from fastdeploy.model_executor.models.qwen3 import Qwen3Attention


+class Qwen3MoeBlock(nn.Layer):
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        layer_id: int,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        weight_key_map = {
+            "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
+            "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
+        }
+        self.experts = FusedMoE(
+            fd_config,
+            moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
+            num_experts=fd_config.model_config.num_experts,
+            top_k=fd_config.model_config.num_experts_per_tok,
+            layer_idx=layer_id,
+            weight_key_map=weight_key_map,
+        )
+
+        self.gate = ReplicatedLinear(
+            fd_config=fd_config,
+            prefix=f"{prefix}.gate",
+            input_size=fd_config.model_config.hidden_size,
+            output_size=fd_config.model_config.num_experts,
+            with_bias=False,
+            skip_quant=True,
+            weight_dtype="float32",
+        )
+
+    def forward(self, x):
+        out = self.experts(x, self.gate)
+        return out
+
+    def load_state_dict(self, state_dict):
+        """ """
+        self.gate.load_state_dict(state_dict)
+        self.experts.load_state_dict(state_dict)
+
+
 class Qwen3MLP(nn.Layer):
     """ """

@@ -104,22 +146,13 @@ class Qwen3DecoderLayer(nn.Layer):
             layer_id=layer_id,
             prefix=f"{prefix}.self_attn",
         )
-
-        weight_key_map = {
-            "gate_weight_key": f"{prefix}.mlp.gate.weight",
-            "up_gate_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.up_gate_proj.weight",
-            "down_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.down_proj.weight",
-        }
-
-        if fd_config.model_config.num_experts is not None and layer_id >= fd_config.model_config.moe_layer_start_index:
-            self.mlp = FusedMoE(
-                fd_config,
-                moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
-                num_experts=fd_config.model_config.num_experts,
-                top_k=fd_config.model_config.num_experts_per_tok,
-                layer_idx=layer_id,
-                weight_key_map=weight_key_map,
-            )
+        mlp_only_layers = (
+            [] if not hasattr(fd_config.model_config, "mlp_only_layers") else fd_config.model_config.mlp_only_layers
+        )
+        if (layer_id not in mlp_only_layers) and (
+            fd_config.model_config.num_experts > 0 and (layer_id + 1) % fd_config.model_config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3MoeBlock(fd_config, layer_id, prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MLP(
                 fd_config,
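The decoder now decides per layer whether to build a Qwen3MoeBlock or a dense Qwen3MLP. A small standalone sketch of that selection rule; the config values in the example are made up:

    def is_moe_layer(layer_id: int, num_experts: int, decoder_sparse_step: int, mlp_only_layers: list) -> bool:
        # Mirrors the condition above: skip explicitly dense layers, then take every
        # decoder_sparse_step-th layer (1-based) when the model defines experts.
        return (layer_id not in mlp_only_layers) and num_experts > 0 and (layer_id + 1) % decoder_sparse_step == 0

    # With sparse step 1 and no dense-only layers, every decoder layer gets the MoE block.
    print([i for i in range(4) if is_moe_layer(i, num_experts=128, decoder_sparse_step=1, mlp_only_layers=[])])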
@@ -279,6 +312,74 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
         """ """
         return "Qwen3MoeForCausalLM"

+    def get_expert_mapping(
+        self,
+    ) -> list[tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            param_gate_up_proj_name="experts.up_gate_proj_",
+            param_down_proj_name="experts.down_proj_",
+            num_experts=self.fd_config.model_config.num_experts,
+        )
+
+    @paddle.no_grad()
+    def load_weights(self, weights_iterator) -> None:
+        """
+        Load model parameters from a given weights_iterator object.
+
+        Args:
+            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
+        """
+
+        from fastdeploy.model_executor.models.utils import default_weight_loader
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("up_gate_proj", "gate_proj", "gate"),
+            ("up_gate_proj", "up_proj", "up"),
+            ("embed_tokens.embeddings", "embed_tokens", None),
+            ("lm_head.linear", "lm_head", None),
+        ]
+        expert_params_mapping = self.get_expert_mapping()
+        params_dict = dict(self.named_parameters())
+        for loaded_weight_name, loaded_weight in weights_iterator:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in loaded_weight_name:
+                    continue
+                if "mlp.experts" in loaded_weight_name:
+                    continue
+                model_param_name = loaded_weight_name.replace(weight_name, param_name)
+                if model_param_name not in params_dict:
+                    continue
+                param = params_dict[model_param_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in loaded_weight_name:
+                        continue
+                    model_param_name = loaded_weight_name.replace(weight_name, param_name)
+                    if model_param_name not in params_dict:
+                        continue
+                    param = params_dict[model_param_name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
+                    break
+                else:
+                    if loaded_weight_name not in params_dict:
+                        continue
+                    param = params_dict[loaded_weight_name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+                    weight_loader(param, loaded_weight)
+
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         """
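A toy illustration of how the stacked-parameter mapping above rewrites checkpoint names before the parameter lookup; the weight name in the example is illustrative, not taken from a real checkpoint:

    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("up_gate_proj", "gate_proj", "gate"),
        ("up_gate_proj", "up_proj", "up"),
    ]

    def resolve(loaded_weight_name: str):
        # Expert weights are skipped here and handled by the expert mapping instead.
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in loaded_weight_name and "mlp.experts" not in loaded_weight_name:
                return loaded_weight_name.replace(weight_name, param_name), shard_id
        return loaded_weight_name, None

    # The q_proj checkpoint tensor is loaded into the fused qkv_proj parameter as the "q" shard.
    print(resolve("model.layers.0.self_attn.q_proj.weight"))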
@@ -72,7 +72,11 @@ def default_weight_loader(fd_config: FDConfig) -> None:
                loaded_weight = loaded_weight[..., shard_offset:shard_size]
            else:
                loaded_weight = loaded_weight[shard_offset:shard_size, ...]

        loaded_weight = get_tensor(loaded_weight)
+        # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
+        if param.dtype != loaded_weight.dtype:
+            loaded_weight = loaded_weight.cast(param.dtype)
+
        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
@@ -156,12 +156,12 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
        # Helper function to add layer mappings
        def _add_layer_mappings(layer_idx: int):
            # MoE specific mappings
-            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = (
+            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = (
                f"{base_name}.{layer_idx}.mlp.gate.weight"
            )

            if self.fd_config.model_config.moe_use_aux_free:
-                self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = (
+                self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.experts.gate_correction_bias"] = (
                    f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
                )

@@ -169,7 +169,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
            for expert_idx in range(self.fd_config.model_config.moe_num_experts):
                for ph in place_holders:
                    # up_gate_proj (up_gate_proj)
-                    up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.up_gate_proj_weight"
+                    up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"
                    if up_gate_proj_key not in self.infer_to_train_mapping:
                        self.infer_to_train_mapping[up_gate_proj_key] = []
                    self.infer_to_train_mapping[up_gate_proj_key].append(
@@ -177,7 +177,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
                    )

                    # down_proj (down_proj)
-                    down_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.down_proj_weight"
+                    down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight"
                    if down_proj_key not in self.infer_to_train_mapping:
                        self.infer_to_train_mapping[down_proj_key] = []
                    self.infer_to_train_mapping[down_proj_key].append(
@@ -230,13 +230,13 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
        def _add_expert_mappings(layer_idx: int, moe_tag: str, expert_start: int):
            # MoE specific mappings
            gate_suffix = "" if moe_tag == "text" else "_1"
-            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_weight"] = (
+            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate.weight"] = (
                f"{base_name}.{layer_idx}.mlp.gate.weight{gate_suffix}"
            )

            if self.fd_config.model_config.moe_use_aux_free:
                self.infer_to_train_mapping[
-                    f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_correction_bias"
+                    f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.experts.gate_correction_bias"
                ] = f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"

            # Initialize defaultdict for expert weights
@@ -255,12 +255,12 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
                expert_num_per_rank,
            ):
                for ph in place_holders:
-                    expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"].append(
-                        f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
-                    )
-                    expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.down_proj_weight"].append(
-                        f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
-                    )
+                    expert_mappings[
+                        f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.experts.up_gate_proj_weight"
+                    ].append(f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}")
+                    expert_mappings[
+                        f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.experts.down_proj_weight"
+                    ].append(f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}")
            self.infer_to_train_mapping.update(expert_mappings)

        moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index
@@ -375,12 +375,12 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel):
        # Helper function to add layer mappings
        def _add_layer_mappings(layer_idx: int):
            # MoE specific mappings
-            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate_weight"] = (
+            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = (
                f"{base_name}.{layer_idx}.mlp.gate.weight"
            )

            if self.fd_config.moe_config.moe_use_aux_free:
-                self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = (
+                self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.experts.gate_correction_bias"] = (
                    f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
                )

@@ -388,7 +388,7 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel):
            for expert_idx in range(self.fd_config.moe_config.num_experts):
                for ph in place_holders:
                    # up_gate_proj (up_gate_proj)
-                    up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.up_gate_proj_weight"
+                    up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"
                    if up_gate_proj_key not in self.infer_to_train_mapping:
                        self.infer_to_train_mapping[up_gate_proj_key] = []
                    self.infer_to_train_mapping[up_gate_proj_key].append(
@@ -396,7 +396,7 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel):
                    )

                    # down_proj (down_proj)
-                    down_proj_key = f"{base_name}.{layer_idx}.mlp.down_proj_weight"
+                    down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight"
                    if down_proj_key not in self.infer_to_train_mapping:
                        self.infer_to_train_mapping[down_proj_key] = []
                    self.infer_to_train_mapping[down_proj_key].append(
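After the rename, the fused inference-side keys map to lists of per-expert checkpoint keys. A concrete toy example of one such entry; the base name, layer index, and placeholder list are illustrative:

    base_name, layer_idx, place_holders = "ernie.layers", 1, ["weight"]   # illustrative values
    mapping = {}
    up_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"
    mapping.setdefault(up_key, [])
    for expert_idx in range(2):            # two experts, for brevity
        for ph in place_holders:
            mapping[up_key].append(f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}")
    print(mapping)
    # {'ernie.layers.1.mlp.experts.up_gate_proj_weight':
    #  ['ernie.layers.1.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.1.up_gate_proj.weight']}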
File diff suppressed because it is too large
@@ -13,7 +13,6 @@
 # limitations under the License.

 import argparse
-import difflib

 from paddleformers.trl.llm_utils import init_dist_env

@@ -50,23 +49,35 @@ for k, v in actor_eval_model.get_name_mappings_to_training().items():
     content += f"{k}:{v}\n"


-def compare_strings(a: str, b: str) -> bool:
-    if a == b:
-        print("✅ 两个字符串完全一致")
-        return True
-
-    print("❌ 字符串不一致,差异如下(上下文差异显示):")
-    diff = difflib.ndiff(a.splitlines(), b.splitlines())
-    for line in diff:
-        if line.startswith("- ") or line.startswith("+ "):
-            print(line)
-    return False
+def compare_strings_line_by_line(a: str, b: str) -> bool:
+    """
+    Compare two multiline strings line by line.
+
+    Returns:
+        True if all lines match exactly in order and content.
+        False if any line differs or the number of lines is not equal.
+    """
+    a_lines = a.splitlines()
+    b_lines = b.splitlines()
+
+    if len(a_lines) != len(b_lines):
+        print(f"❌ Mismatch in number of lines: expected {len(a_lines)}, but got {len(b_lines)}.")
+        return False
+
+    for i, (line_a, line_b) in enumerate(zip(a_lines, b_lines)):
+        if line_a != line_b:
+            print(f"❌ Difference found on line {i + 1}:")
+            print(f"    Expected: {repr(line_a)}")
+            print(f"    Actual  : {repr(line_b)}")
+            return False
+
+    print("✅ All lines match exactly.")
+    return True


 with open("baseline.txt", "r", encoding="utf-8") as f:
     baseline = f.read()
-assert compare_strings(baseline, content), (
+assert compare_strings_line_by_line(baseline, content), (
     "In the unittest of RL scenario, your modification "
     "caused inconsistency in the content before and after. Please fix it. "
     "Can request assistance from yuanlehome or gzy19990617 (github id)."
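Call-site usage is unchanged apart from the name. A quick sketch, assuming the helper defined above is in scope:

    assert compare_strings_line_by_line("a\nb", "a\nb") is True      # prints the success message
    assert compare_strings_line_by_line("a\nb", "a\nc") is False     # reports the first differing line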