commit 20839abccf (parent 91dc87f1c5)
Author: bukejiyu
Date: 2025-08-06 14:45:27 +08:00
Committed by: GitHub
30 changed files with 1361 additions and 1087 deletions

View File

@@ -663,7 +663,7 @@ class LoadChoices(str, Enum):
     DEFAULT = "default"
     # only support qwen3-bf16 now
-    NEW_LOADER = "new_loader"
+    DEFAULT_V1 = "default_v1"


 class LoadConfig:

View File

@@ -22,7 +22,9 @@ import paddle
 from paddle import nn
 from paddleformers.utils.log import logger

-from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
+from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
+    UnquantizedFusedMoEMethod,
+)
 from fastdeploy.model_executor.layers.utils import (
     CpuGuard,
     create_and_set_parameter,
@@ -37,7 +39,7 @@ from fastdeploy.model_executor.ops.gcu import (
 )


-class GCUFusedMoeMethod(MoEMethodBase):
+class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
     """
     Use GCU to compute Fused MoE.
     """
@@ -46,28 +48,12 @@ class GCUFusedMoeMethod(MoEMethodBase):
         super().__init__(quant_config)
         self.group_size = -1

-    def create_weights(self, layer: nn.Layer, state_dict):
-        """
-        Paddle gcu create weight process.
-        """
-        # bf16
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
         up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
         stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
         stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
-        for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
-            # shape [E, K, N] -> [E, N, K]
-            weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
-            weight_name = self.added_weight_attrs[idx]
-            setattr(
-                layer,
-                weight_name,
-                layer.create_parameter(
-                    shape=weight_tensor.shape,
-                    dtype=weight_tensor.dtype,
-                    default_initializer=paddle.nn.initializer.Constant(0),
-                ),
-            )
-            getattr(layer, weight_name).set_value(weight_tensor)
+        # shape [E, K, N] -> [E, N, K]
+        layer.up_gate_proj_weight.set_value(paddle.transpose(stacked_up_gate_proj_weights, [0, 2, 1]))
+        layer.down_proj_weight.set_value(paddle.transpose(stacked_down_proj_weights, [0, 2, 1]))

     @paddle.no_grad()
     def compute_ffn(
@@ -202,18 +188,19 @@ class GCUFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle gcu compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         return self.compute_ffn(layer, x, gate_out, enable_quant=False)

     def apply_ep_prefill(
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
@@ -224,7 +211,7 @@ class GCUFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
@@ -235,7 +222,7 @@ class GCUFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
@@ -400,9 +387,10 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle gcu compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         return self.compute_ffn(layer, x, gate_out, enable_quant=True)
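A hedged sketch of the routing contract change above (stand-in sizes and a plain nn.Linear as the router; in FastDeploy the gate is a ReplicatedLinear, and these names are illustrative only):

import paddle
from paddle import nn

hidden_size, num_experts = 1024, 64
gate = nn.Linear(hidden_size, num_experts, bias_attr=False)   # stand-in for the MoE gate layer
x = paddle.randn([8, hidden_size])                            # 8 tokens of hidden states

# Callers used to precompute the routing logits and pass a tensor (gate_out);
# after this change the backend receives the gate layer and computes them itself,
# always in float32 for routing stability:
gate_out = gate(x.cast("float32"))                            # [8, num_experts] logits
# ...the backend then continues with compute_ffn(layer, x, gate_out, ...)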

View File

@@ -37,7 +37,7 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         self.quant_config = quant_config
         self.group_size = -1

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
         weight_scale_shape = [layer.weight_shape[1]]

@@ -45,6 +45,14 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         if self.quant_config.name() == "wint4":
             layer.weight_shape[0] //= 2
         layer.weight_dtype = "int8"
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
         layer.weight_scale = layer.create_parameter(
             shape=weight_scale_shape,
             dtype=layer._dtype,

View File

@@ -35,7 +35,7 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
     ) -> None:
         super().__init__(quant_config)

-    def create_weights(self, layer: nn.Layer) -> None:
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs) -> None:
         """
         Create weights for linear layer on XPU
         """
@@ -45,6 +45,12 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         if self.quant_config.name() == "weight_only_int4":
             layer.weight_shape[0] //= 2
         layer.weight_dtype = "int8"
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
         layer.weight_scale = layer.create_parameter(
             shape=weight_scale_shape,
             dtype="float32",

View File

@@ -21,6 +21,7 @@ from paddle import nn

 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
+from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.models.utils import (
     default_weight_loader,
     set_weight_attrs,
@@ -30,6 +31,45 @@ from fastdeploy.platforms import current_platform

 from .utils import _set_var_distributed, divide, get_tensor


+class UnquantizedLinearMethod(QuantMethodBase):
+    """Linear method without quantization."""
+
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
+        """
+        extra_weight_attrs is a dictionary that may include parameters like:
+            - split_axis: specifies which axis to split the weight tensor on (for distributed weight partitioning)
+            - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
+            - weight_loader: a callable or method responsible for loading the weight data
+        """
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        set_weight_attrs(
+            layer.weight,
+            {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
+        )
+        if hasattr(layer, "nranks") and layer.nranks > 0:
+            split_axis = extra_weight_attrs.get("split_axis")
+            _set_var_distributed(layer.weight, split_axis=split_axis)
+            set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})
+
+    def process_loaded_weights(self, layer, weights) -> None:
+        # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
+        if layer.weight.dtype != weights.dtype:
+            weights = weights.cast(layer.weight.dtype)
+        layer.weight.set_value(weights)
+
+    def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
+        linear_out = paddle.matmul(x, layer.weight)
+        if layer.with_bias:
+            linear_out = paddle.add(linear_out, layer.bias)
+        return linear_out
+
+
 class LinearBase(nn.Layer):
     """
     LinearBase Layer.
@@ -44,6 +84,8 @@ class LinearBase(nn.Layer):
         with_bias: bool = False,
         add_bias: bool = False,
         skip_quant: bool = False,
+        weight_dtype: str = "",
+        weight_key: str = "",
     ):
         """
         Initializes a linear layer and provides additional parameters required for inference and quantization.
@@ -81,6 +123,9 @@ class LinearBase(nn.Layer):
         self.add_bias = add_bias
         self.prefix = prefix
         # key
-        self.weight_key = f"{prefix}.weight"
+        if weight_key:
+            self.weight_key = f"{prefix}.{weight_key}"
+        else:
+            self.weight_key = f"{prefix}.weight"
         self.bias_key = f"{prefix}.bias"
         self.shift_key = f"{prefix}.shift_bias"
@@ -88,39 +133,21 @@ class LinearBase(nn.Layer):
         self.out_scale_key = f"{prefix}.out_scale"

         self._dtype = self._helper.get_default_dtype()
-        self.weight_dtype = self._dtype
+        if weight_dtype:
+            self.weight_dtype = weight_dtype
+        elif self.skip_quant:
+            self.weight_dtype = self._dtype
+        else:
+            self.weight_dtype = self._dtype
         self.weight_shape = [
             self.input_size,
             self.output_size,
         ]
-        if fd_config.quant_config:
+        if fd_config.quant_config and not skip_quant:
             self.quant_method = fd_config.quant_config.get_quant_method(self)
-            if fd_config.model_config.is_quantized:
-                self.weight_key = f"{prefix}.quant_weight"
-                self.weight_scale_key = f"{prefix}.weight_scale"
-                self.act_scale_key = f"{prefix}.activation_scale"
-
-    def init_weight(self):
-        """
-        Initialize the weights and biases.
-        """
-        if self.skip_quant:
-            self.weight_dtype = self._dtype
-        self.weight = self.create_parameter(
-            shape=self.weight_shape,
-            dtype=self.weight_dtype,
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        set_weight_attrs(
-            self.weight,
-            {
-                "weight_loader": (
-                    self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
-                )
-            },
-        )
+        else:
+            self.quant_method: Optional[QuantMethodBase] = UnquantizedLinearMethod()

         self.bias = None
         if self.with_bias:
@@ -130,19 +157,15 @@ class LinearBase(nn.Layer):
                 is_bias=True,
             )
-            set_weight_attrs(
-                self.weight,
-                {
-                    "weight_loader": (
-                        self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
-                    )
-                },
-            )

         # smooth quant
         self.linear_shift = None
         self.linear_smooth = None

+        if fd_config.model_config.is_quantized:
+            self.weight_key = f"{prefix}.quant_weight"
+            self.weight_scale_key = f"{prefix}.weight_scale"
+            self.act_scale_key = f"{prefix}.activation_scale"
+
     def load_prequant_weight(self, state_dict: dict):
         """
         Load the prequantized weight from the state dictionary.
@@ -160,11 +183,7 @@ class LinearBase(nn.Layer):
             state_dict (dict): A dictionary containing the weights
         """
         weight_tensor = get_tensor(state_dict.pop(self.weight_key))
-        if self.fd_config.quant_config:
-            self.quant_method.process_loaded_weights(self, weight_tensor)
-        else:
-            self.weight.set_value(weight_tensor)
+        self.quant_method.process_loaded_weights(self, weight_tensor)

     def load_state_dict(self, state_dict: dict):
         """
@@ -199,12 +218,7 @@ class LinearBase(nn.Layer):
         Raises:
             NotImplementedError: If the weight dtype is not float8 or act dtype is not equal to weight dtype.
         """
-        if self.fd_config.quant_config:
-            linear_out = self.quant_method.apply(self, x)
-        else:
-            linear_out = paddle.matmul(x, self.weight)
-            if self.with_bias:
-                linear_out = paddle.add(linear_out, self.bias)
+        linear_out = self.quant_method.apply(self, x)

         return linear_out
@@ -223,6 +237,8 @@ class ReplicatedLinear(LinearBase):
         with_bias: bool = False,
         add_bias: bool = False,
         skip_quant: bool = False,
+        weight_dtype: str = "",
+        weight_key: str = "",
     ):
         """
         Initializes a replicated linear layer.
@@ -245,6 +261,8 @@ class ReplicatedLinear(LinearBase):
             with_bias=with_bias,
             add_bias=add_bias,
             skip_quant=skip_quant,
+            weight_dtype=weight_dtype,
+            weight_key=weight_key,
         )

         self.hidden_size = fd_config.model_config.hidden_size
@@ -252,9 +270,14 @@ class ReplicatedLinear(LinearBase):
             self.input_size,
             self.output_size,
         ]
-        if fd_config.quant_config:
-            self.quant_method.create_weights(self)
-        self.init_weight()
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            self,
+            weight_loader=(
+                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            ),
+        )


 class ColumnParallelLinear(LinearBase):
@@ -306,60 +329,22 @@ class ColumnParallelLinear(LinearBase):
             self.input_size,
             self.output_size,
         ]
-        if fd_config.quant_config:
-            self.quant_method.create_weights(self)
-        self.init_weight()
-
-    def init_weight(self):
-        """
-        Initialize the weights and biases.
-        """
-        if self.skip_quant:
-            self.weight_dtype = self._dtype
-        self.weight = self.create_parameter(
-            shape=self.weight_shape,
-            dtype=self.weight_dtype,
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        if self.nranks > 0:
-            # col parallel
-            _set_var_distributed(self.weight, split_axis=1)
-        set_weight_attrs(
-            self.weight,
-            {
-                "output_dim": True,
-                "weight_loader": (
-                    self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
-                ),
-            },
-        )
-        self.bias = None
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            self,
+            split_axis=1,
+            output_dim=True,
+            weight_loader=(
+                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            ),
+        )
+
         if self.with_bias:
-            self.bias = self.create_parameter(
-                shape=[self.output_size],
-                dtype=self._dtype,
-                is_bias=True,
-            )
             if self.nranks > 0:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=1)
-                set_weight_attrs(
-                    self.weight,
-                    {
-                        "output_dim": True,
-                        "weight_loader": (
-                            self.weight_loader
-                            if hasattr(self, "weight_loader")
-                            else default_weight_loader(self.fd_config)
-                        ),
-                    },
-                )
-
-        # smooth quant
-        self.linear_shift = None
-        self.linear_smooth = None
+                set_weight_attrs(self.bias, {"output_dim": True})


 class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -429,9 +414,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         loaded_weight = get_tensor(loaded_weight)

         if loaded_shard_id == "gate":
-            param[:, : self.output_size // 2] = loaded_weight
+            param = param[:, : self.output_size // 2]
         elif loaded_shard_id == "up":
-            param[:, self.output_size // 2 :] = loaded_weight
+            param = param[:, self.output_size // 2 :]
+
+        assert param.shape == loaded_weight.shape, (
+            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+        )
+        param.copy_(loaded_weight, False)

     def load_state_dict(self, state_dict: dict):
         """
@@ -518,16 +508,21 @@ class QKVParallelLinear(ColumnParallelLinear):
         loaded_weight = get_tensor(loaded_weight)

         if loaded_shard_id == "q":
-            param[:, : self.num_heads_per_rank * self.head_dim] = loaded_weight
+            param = param[:, : self.num_heads_per_rank * self.head_dim]
         elif loaded_shard_id == "k":
-            param[
+            param = param[
                 :,
                 self.num_heads_per_rank
                 * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
                 * self.head_dim,
-            ] = loaded_weight
+            ]
         elif loaded_shard_id == "v":
-            param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :] = loaded_weight
+            param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
+
+        assert param.shape == loaded_weight.shape, (
+            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+        )
+        param.copy_(loaded_weight, False)

     def load_weight(self, state_dict: dict):
         """
@@ -665,62 +660,25 @@ class RowParallelLinear(LinearBase):
         ]
         self._dtype = self._helper.get_default_dtype()

-        if fd_config.quant_config:
-            self.quant_method = fd_config.quant_config.get_quant_method(self)
-            self.quant_method.create_weights(self)
-
-        self.reduce_results = reduce_results
-        self.init_weight()
-
-    def init_weight(self):
-        """
-        Initialize the weights and biases.
-        """
-        if self.skip_quant:
-            self.weight_dtype = self._dtype
-        self.weight = self.create_parameter(
-            shape=self.weight_shape,
-            dtype=self.weight_dtype,
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        if self.nranks > 0:
-            # row parallel
-            set_weight_attrs(
-                self.weight,
-                {
-                    "output_dim": False,
-                    "weight_loader": (
-                        self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
-                    ),
-                },
-            )
-            _set_var_distributed(self.weight, split_axis=0)
-
-        self.bias = None
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            self,
+            split_axis=0,
+            output_dim=False,
+            weight_loader=(
+                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            ),
+        )
+
         if self.with_bias:
-            self.bias = self.create_parameter(
-                shape=[self.hidden_size],
-                dtype=self._dtype,
-                is_bias=True,
-            )
-            if self.nranks > 0:
-                set_weight_attrs(
-                    self.bias,
-                    {
-                        "output_dim": False,
-                        "weight_loader": (
-                            self.weight_loader
-                            if hasattr(self, "weight_loader")
-                            else default_weight_loader(self.fd_config)
-                        ),
-                    },
-                )
-
-        # smooth quant
-        self.linear_shift = None
-        self.linear_smooth = None
+            _set_var_distributed(self.bias, split_axis=0)
+            set_weight_attrs(
+                self.bias,
+                {
+                    "output_dim": False,
+                },
+            )
+
+        self.reduce_results = reduce_results

     def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
         if self.fd_config.quant_config:
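As a worked illustration of the copy-based weight loaders above, the fused QKV output dimension is sliced head-block by head-block before param.copy_(loaded_weight, False); the sizes below are assumptions for the example, not values from any real config:

# Hypothetical per-rank sizes: 8 query heads, 2 KV heads, head_dim = 128.
num_heads_per_rank, kv_num_heads_per_rank, head_dim = 8, 2, 128
q_end = num_heads_per_rank * head_dim                              # 1024
kv_end = (num_heads_per_rank + kv_num_heads_per_rank) * head_dim   # 1280
# "q" loads into param[:, :1024], "k" into param[:, 1024:1280], "v" into param[:, 1280:],
# and each slice's shape must equal loaded_weight.shape before the in-place copy_.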

View File

@@ -19,6 +19,9 @@ from abc import abstractmethod
 import paddle
 from paddle import nn

+from fastdeploy.model_executor.models.utils import set_weight_attrs
+from fastdeploy.platforms import current_platform
+
 from ..quantization.quant_base import QuantMethodBase
@@ -125,7 +128,7 @@ class MoEMethodBase(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
@@ -137,7 +140,7 @@ class MoEMethodBase(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
@@ -149,7 +152,7 @@ class MoEMethodBase(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
@@ -160,7 +163,7 @@ class MoEMethodBase(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
@@ -168,9 +171,35 @@ class MoEMethodBase(QuantMethodBase):
         if layer.ep_size > 1:
             if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
                 self.ep_prefill_runner.clean_low_latency_buffer()
-                return self.apply_ep_prefill(layer, x, gate_out)
+                return self.apply_ep_prefill(layer, x, gate)
             else:
                 self.ep_decoder_runner.clean_low_latency_buffer()
-                return self.apply_ep_decode(layer, x, gate_out)
+                return self.apply_ep_decode(layer, x, gate)
         else:
-            return self.apply_tp(layer, x, gate_out)
+            return self.apply_tp(layer, x, gate)
+
+
+class UnquantizedFusedMoEMethod(MoEMethodBase):
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
+        if current_platform.is_cuda():
+            self.up_gate_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2]
+            self.down_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size, layer.hidden_size]
+        else:
+            self.up_gate_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size * 2, layer.hidden_size]
+            self.down_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size]
+
+        layer.up_gate_proj_weight = layer.create_parameter(
+            shape=self.up_gate_proj_weight_shape,
+            dtype=layer.weight_dtype,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        layer.down_proj_weight = layer.create_parameter(
+            shape=self.down_proj_weight_shape,
+            dtype=layer.weight_dtype,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+
+        set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs)
+        set_weight_attrs(layer.down_proj_weight, extra_weight_attrs)
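A minimal sketch (with assumed sizes) of the per-platform weight shapes that UnquantizedFusedMoEMethod.create_weights allocates; E, H and I stand for num_experts, hidden_size and moe_intermediate_size:

E, H, I = 64, 1024, 2048
cuda_shapes = {
    "up_gate_proj_weight": [E, H, 2 * I],   # CUDA kernels consume the [E, K, N] layout
    "down_proj_weight": [E, I, H],
}
non_cuda_shapes = {
    "up_gate_proj_weight": [E, 2 * I, H],   # GCU/XPU-style kernels expect the transposed layout
    "down_proj_weight": [E, H, I],
}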

View File

@@ -24,7 +24,7 @@ from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.platforms import current_platform

 from ..utils import create_and_set_parameter, get_tensor
-from .fused_moe_backend_base import MoEMethodBase
+from .fused_moe_backend_base import UnquantizedFusedMoEMethod

 if current_platform.is_cuda():
     from fastdeploy.model_executor.ops.gpu import (
@@ -64,32 +64,19 @@ def get_moe_scores(
     return scores, topk_values, topk_idx


-class CutlassMoEMethod(MoEMethodBase):
+class CutlassMoEMethod(UnquantizedFusedMoEMethod):
     """
     Use Cutlass Group Gemm to compute Fused MoE.
     This method is the oldest way to compute MoE in Paddle.
     """

-    def create_weights(self, layer: nn.Layer, state_dict):
-        """
-        Paddle cutlass create weight process.
-        """
-        # bf16
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
         up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
         stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
         stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
-        for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
-            weight_name = self.added_weight_attrs[idx]
-            setattr(
-                layer,
-                weight_name,
-                layer.create_parameter(
-                    shape=weight_tensor.shape,
-                    dtype=weight_tensor.dtype,
-                    default_initializer=paddle.nn.initializer.Constant(0),
-                ),
-            )
-            getattr(layer, weight_name).set_value(weight_tensor)
+
+        layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
+        layer.down_proj_weight.set_value(stacked_down_proj_weights)

     def compute_ffn(
         self,
@@ -134,11 +121,12 @@ class CutlassMoEMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
         """
+        gate_out = gate(x.cast("float32"))
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
         # 2. EP Dispatch
@@ -206,11 +194,12 @@ class CutlassMoEMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
         """
+        gate_out = gate(x.cast("float32"))
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
         expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
@@ -242,11 +231,12 @@ class CutlassMoEMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         if layer.topk_method == "noaux_tc":
             gate_out, _, _ = get_moe_scores(
                 gate_out,

View File

@@ -126,11 +126,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
         """
+        gate_out = gate(x.cast("float32"))
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
         # 2. Dynamic compute blockwise quantization scales
@@ -233,11 +234,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
         """
+        gate_out = gate(x.cast("float32"))
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
         # 2. EP Dispatch
@@ -303,13 +305,13 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Use DeepGemm compute Fused MoE.
         below is TP compute method.
         """
+        gate_out = gate(x.cast("float32"))
         topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
             gate_out,
             layer.gate_correction_bias,

View File

@@ -219,11 +219,12 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Marlin compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         token_num = x.shape[0]
         top_k = layer.top_k

View File

@@ -115,11 +115,12 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         token_num = x.shape[0]
         top_k = layer.top_k
         num_local_experts = layer.num_local_experts
@@ -336,12 +337,12 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         token_num = x.shape[0]
         top_k = layer.top_k
         num_local_experts = layer.num_local_experts
@@ -576,12 +577,12 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         token_num = x.shape[0]
         top_k = layer.top_k
         num_local_experts = layer.num_local_experts

View File

@@ -171,12 +171,12 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Use Wint2 Triton Fusedmoe compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch

         (
@@ -242,12 +242,12 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Use Wint2 Triton Fusedmoe compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         from fastdeploy.model_executor.ops.triton_ops import moe_wint2_ffn_kernel

         topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(

View File

@@ -19,47 +19,36 @@ from typing import Dict
 import paddle
 from paddle import nn

+from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
+    UnquantizedFusedMoEMethod,
+)
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
 from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu

-from .fused_moe_backend_base import MoEMethodBase
-

-class XPUMoEMethod(MoEMethodBase):
+class XPUMoEMethod(UnquantizedFusedMoEMethod):
     """
     XPU MOE
     """

-    def create_weights(self, layer: nn.Layer, state_dict):
-        """
-        Paddle cutlass create weight process.
-        """
-        # bf16
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
         up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
         for weights in [up_gate_proj_weights, down_proj_weights]:
             for idx, weight in enumerate(weights):
                 weights[idx] = weight.transpose([1, 0])
         stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
         stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
-        for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
-            weight_name = self.added_weight_attrs[idx]
-            setattr(
-                layer,
-                weight_name,
-                layer.create_parameter(
-                    shape=weight_tensor.shape,
-                    dtype=weight_tensor.dtype,
-                    default_initializer=paddle.nn.initializer.Constant(0),
-                ),
-            )
-            getattr(layer, weight_name).set_value(weight_tensor)
+
+        layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
+        layer.down_proj_weight.set_value(stacked_down_proj_weights)

     def apply_tp(
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
@@ -68,7 +57,7 @@ class XPUMoEMethod(MoEMethodBase):
         fused_moe_out = xpu_moe_layer(
             x,
-            layer.gate_weight.transpose([1, 0]),
+            gate.weight.transpose([1, 0]),
             layer.gate_correction_bias,
             layer.up_gate_proj_weight,
             layer.down_proj_weight,
@@ -94,7 +83,7 @@ class XPUMoEMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
@@ -105,7 +94,7 @@ class XPUMoEMethod(MoEMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
@@ -187,7 +176,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         XPU compute Fused MoE.
@@ -196,7 +185,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
         fused_moe_out = xpu_moe_layer(
             x,
-            layer.gate_weight.transpose([1, 0]),
+            gate.weight.transpose([1, 0]),
             layer.gate_correction_bias,
             layer.up_gate_proj_weight,
             layer.down_proj_weight,

View File

@@ -14,6 +14,8 @@
 # limitations under the License.
 """

+from typing import Optional
+
 import paddle
 from paddle import nn
 from paddleformers.utils.log import logger
@@ -77,7 +79,7 @@ class FusedMoE(nn.Layer):
         self.fd_config = fd_config
         self.layer_idx = layer_idx
         self.reduce_results = reduce_results

+        self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
         self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.ep_size = fd_config.parallel_config.expert_parallel_size
         self.ep_rank = fd_config.parallel_config.expert_parallel_rank
@@ -109,14 +111,19 @@ class FusedMoE(nn.Layer):
         self.n_group = n_group
         self.routed_scaling_factor = routed_scaling_factor

+        self._dtype = self._helper.get_default_dtype()
+        self.weight_dtype = self._dtype
+
         moe_quant_config = fd_config.quant_config
+        self.moe_quant_config = moe_quant_config
         self.moe_quant_type = None
         if moe_quant_config:
             self.quant_method = moe_quant_config.get_quant_method(self)
             self.moe_quant_type = moe_quant_config.name()
         else:
-            # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
+            # w_fp16 a_fp16
             self.quant_method = get_moe_method()
+        self.quant_method.create_weights(self, weight_loader=self.weight_loader)

         self.redundant_table_manger = None
         if self.ep_size > 1:
@@ -140,21 +147,121 @@ class FusedMoE(nn.Layer):
                 tp_size={self.tp_size}."
             )

+    def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None):
+        from fastdeploy.platforms import current_platform
+
+        if shard_id is None:
+            # 1.gate up fused in disk
+            return
+        # 2.gate up splited in disk
+        assert shard_id in ["gate", "down", "up"]
+        expert_param = param[expert_id]
+        if current_platform.is_cuda():
+            SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
+        else:
+            SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
+        self._load_expert_weight(
+            expert_param=expert_param,
+            shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
+            loaded_weight=loaded_weight,
+            shard_id=shard_id,
+        )
+
+    def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
+        tensor_size = expert_param.shape[shard_dim] // 2
+        if shard_id == "gate":
+            expert_param = expert_param[..., :tensor_size] if shard_dim else expert_param[:tensor_size, ...]
+        elif shard_id == "up":
+            expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...]
+
+        if self.tp_size > 1:
+            size = loaded_weight.get_shape()[-1]
+            block_size = size // self.tp_size
+            shard_offset = self.tp_rank * block_size
+            shard_size = (self.tp_rank + 1) * block_size
+            loaded_weight = loaded_weight[..., shard_offset:shard_size]
+
+        loaded_weight = get_tensor(loaded_weight)
+
+        # To ensure compatibility across backends, apply an extra transpose for GCU and XPU
+        if expert_param.shape != loaded_weight.shape:
+            loaded_weight = loaded_weight.transpose([1, 0])
+        assert expert_param.shape == loaded_weight.shape, (
+            f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
+        )
+        expert_param.copy_(loaded_weight, False)
+
+    def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
+        if self.tp_size > 1:
+            size = loaded_weight.get_shape()[shard_dim]
+            block_size = size // self.tp_size
+            shard_offset = self.tp_rank * block_size
+            shard_size = (self.tp_rank + 1) * block_size
+            loaded_weight = loaded_weight[shard_offset:shard_size, ...]
+        loaded_weight = get_tensor(loaded_weight)
+        # To ensure compatibility across backends, apply an extra transpose for GCU and XPU
+        if expert_param.shape != loaded_weight.shape:
+            loaded_weight = loaded_weight.transpose([1, 0])
+        assert expert_param.shape == loaded_weight.shape, (
+            f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
+        )
+        expert_param.copy_(loaded_weight, False)
+
+    def _load_expert_weight(
+        self,
+        expert_param,
+        shard_dim,
+        loaded_weight,
+        shard_id,
+    ):
+        if shard_id == "down":
+            self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
+        elif shard_id in ["gate", "up"]:
+            self._load_gate_up_weight(expert_param, shard_dim, loaded_weight, shard_id)
+
+    @classmethod
+    def make_expert_params_mapping(
+        cls,
+        ckpt_gate_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_up_proj_name: str,
+        param_gate_up_proj_name: str,
+        param_down_proj_name: str,
+        num_experts: int,
+        ckpt_expert_key_name: str = "experts",
+        ckpt_gate_up_proj_name: Optional[str] = None,
+    ) -> list[tuple[str, str, int, str]]:
+        param_name_maping = [
+            ("gate", ckpt_gate_proj_name),
+            ("down", ckpt_down_proj_name),
+            ("up", ckpt_up_proj_name),
+        ]
+        if ckpt_gate_up_proj_name:
+            param_name_maping.append((None, ckpt_gate_up_proj_name))
+
+        return [
+            # (param_name, weight_name, expert_id, shard_id)
+            (
+                (
+                    param_gate_up_proj_name
+                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                    else param_down_proj_name
+                ),
+                f"{ckpt_expert_key_name}.{expert_id}.{weight_name}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in param_name_maping
+        ]
+
     def init_moe_weights(self):
         """
         Initialize the weight shapes and parameters for the MoE layer.
         Combines weight shape initialization and parameter creation into a single function.
         """
         # Initialize weight shapes
-        self._dtype = self._helper.get_default_dtype()
-        self.weight_dtype = self._dtype
-        gate_weight_shape = [self.hidden_size, self.num_experts]
         gate_correction_bias_shape = [1, self.num_experts]

-        self.gate_weight = self.create_parameter(
-            shape=gate_weight_shape,
-            dtype="float32",
-        )
         if self.fd_config.model_config.moe_use_aux_free:
             self.gate_correction_bias = self.create_parameter(
                 shape=gate_correction_bias_shape,
@@ -374,26 +481,19 @@ class FusedMoE(nn.Layer):
             )
             self.gate_correction_bias.set_value(gate_correction_bias_tensor)

-        gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
-        assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
-        gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
-
-        self.gate_weight = self.create_parameter(
-            shape=gate_weight_tensor.shape,
-            dtype="float32",
-        )
-        self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
-
         if self.fd_config.model_config.is_quantized:
             if getattr(self.fd_config.quant_config, "is_permuted", True):
                 self.quant_method.process_prequanted_weights(self, state_dict)
             else:
                 self.quant_method.create_weights(self, state_dict)
         else:
-            self.quant_method.create_weights(self, state_dict)
+            if self.moe_quant_config:
+                self.quant_method.create_weights(self, state_dict)
+            else:
+                # w_fp16 a_fp16
+                self.quant_method.process_loaded_weights(self, state_dict)

-    def forward(self, x: paddle.Tensor):
+    def forward(self, x: paddle.Tensor, gate: nn.Layer):
         """
         Defines the forward computation of the moe layer.
@@ -404,6 +504,5 @@ class FusedMoE(nn.Layer):
             Tensor: Output tensor.s
         """
-        gate_out = paddle.matmul(x.cast("float32"), self.gate_weight)
-        out = self.quant_method.apply(self, x, gate_out)
+        out = self.quant_method.apply(self, x, gate)
         return out
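A hedged illustration of the new expert-parameter mapping helper and the forward contract; the checkpoint and parameter names below are assumptions for the example, not taken from a real model:

mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    param_gate_up_proj_name="experts.up_gate_proj_weight",
    param_down_proj_name="experts.down_proj_weight",
    num_experts=2,
)
# Produces (param_name, weight_name, expert_id, shard_id) tuples such as:
#   ("experts.up_gate_proj_weight", "experts.0.gate_proj.", 0, "gate")
#   ("experts.down_proj_weight",    "experts.0.down_proj.", 0, "down")
#   ("experts.up_gate_proj_weight", "experts.0.up_proj.",   0, "up")
# ...and the same three entries for expert 1.
#
# The layer itself is now called with the router layer instead of precomputed logits:
#   out = moe_layer(hidden_states, gate)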

View File

@@ -81,8 +81,16 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         super().__init__()
         self.quant_config = quant_config

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         layer.weight_shape.reverse()
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
         layer.weight_scale = layer.create_parameter(
             shape=[
                 (layer.output_size + self.quant_config.weight_block_size[0] - 1)

View File

@@ -16,6 +16,8 @@
 from typing import Optional

+import paddle
+
 from fastdeploy.model_executor.layers.moe import FusedMoE

 from ..utils import get_tensor
@@ -79,11 +81,14 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
         self.quant_round_type = 1
         self.weight_dtype = "float8_e4m3fn"

-    def create_weights(self, layer):
-        """
-        Nothing to do!
-        """
-        pass
+    def create_weights(self, layer, **extra_weight_attrs):
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )

     def process_prequanted_weights(self, layer, state_dict) -> None:
         """

View File

@@ -63,11 +63,17 @@ class W4AFP8LinearMethod(QuantMethodBase):
         super().__init__()
         self.quant_config = quant_config

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         layer.weight_shape.reverse()
         layer.weight_shape[0] //= 2
         layer.weight_dtype = "int8"
-        pass
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )

     def process_loaded_weights(self, layer, weights) -> None:
         (

View File

@@ -74,7 +74,7 @@ class W8A8LinearMethod(QuantMethodBase):
         self.quant_config = quant_config
         self.smooth_quant_method = SmoothQuantLinearMethod(quant_config)

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         layer.weight_shape.reverse()
         layer.weight_dtype = "int8"
         if self.quant_config.use_smooth_quant:
@@ -85,7 +85,12 @@ class W8A8LinearMethod(QuantMethodBase):
         if weight_scale is None or in_scale is None:
             self.skip_quant = True
             return
+        layer.wieght = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
         max_range = 127.0
         linear_out_scale = paddle.to_tensor(weight_scale / (max_range * max_range * in_scale)).astype("float32")
         layer.linear_out_scale = layer.create_parameter(
@@ -136,7 +141,7 @@ class SmoothQuantLinearMethod(QuantMethodBase):
         super().__init__()
         self.quant_config = quant_config

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         linear_shift_shape = [layer.output_size]
         linear_smooth_shape = [layer.output_size]
         layer.linear_shift = self.create_parameter(

View File

@@ -168,7 +168,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         super().__init__()
         self.quant_config = quant_config

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
         weight_scale_shape = [layer.weight_shape[1]]

@@ -177,6 +177,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         if self.quant_config.name() == "wint4":
             layer.weight_shape[0] //= 2
         layer.weight_dtype = "int8"
+        layer.weight = layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
         layer.weight_scale = layer.create_parameter(
             shape=weight_scale_shape,
             dtype=layer._dtype,

View File

@@ -69,12 +69,18 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
         super().__init__()
         self.quant_config = quant_config

-    def create_weights(self, layer):
+    def create_weights(self, layer, **extra_weight_attrs):
         """ """
         layer.weight_shape.reverse()
         layer.weight_dtype = "float8_e4m3fn"
         # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
         self.skip_quant = False
+        layer.create_parameter(
+            shape=layer.weight_shape,
+            dtype=layer.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
         layer.weight_scale = layer.create_parameter(
             shape=[1],
             dtype="float32",

View File

@@ -17,14 +17,16 @@
 from fastdeploy.config import LoadChoices, LoadConfig
 from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
 from fastdeploy.model_executor.model_loader.default_loader import DefaultModelLoader
-from fastdeploy.model_executor.model_loader.new_loader import NewModelLoader
+from fastdeploy.model_executor.model_loader.default_loader_v1 import (
+    DefaultModelLoaderV1,
+)


 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """get_model_loader"""
-    if load_config.load_choices == LoadChoices.NEW_LOADER:
-        return NewModelLoader(load_config)
+    if load_config.load_choices == LoadChoices.DEFAULT_V1:
+        return DefaultModelLoaderV1(load_config)

     return DefaultModelLoader(load_config)
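Because LoadChoices is a str-valued enum, the renamed member still compares equal to its plain string; a small sketch of the selection logic (construction of a full LoadConfig is omitted here):

from fastdeploy.config import LoadChoices

assert LoadChoices.DEFAULT_V1 == "default_v1"
# get_model_loader(load_config) returns DefaultModelLoaderV1 when
# load_config.load_choices == LoadChoices.DEFAULT_V1, and DefaultModelLoader otherwise.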

View File

@@ -14,6 +14,8 @@
 # limitations under the License.
 """

+import contextlib
+
 import paddle
 from paddle import nn
 from paddleformers.utils.log import logger
@@ -62,15 +64,16 @@ class DefaultModelLoader(BaseModelLoader):
         self.clean_memory_fragments(state_dict)

     def load_model(self, fd_config: FDConfig) -> nn.Layer:
-        context = paddle.LazyGuard()
         architectures = fd_config.model_config.architectures[0]
         logger.info(f"Starting to load model {architectures}")

         if fd_config.load_config.dynamic_load_weight:
             # register rl model
             import fastdeploy.rl  # noqa

             architectures = architectures + "RL"
+            context = paddle.LazyGuard()
+        else:
+            context = contextlib.nullcontext()

         with context:
             model_cls = ModelRegistry.get_class(architectures)

View File

@@ -14,6 +14,8 @@
 # limitations under the License.
 """

+import contextlib
+
 import paddle
 from paddle import nn
 from paddleformers.utils.log import logger
@@ -29,7 +31,7 @@ from fastdeploy.model_executor.models.model_base import ModelRegistry
 from fastdeploy.platforms import current_platform


-class NewModelLoader(BaseModelLoader):
+class DefaultModelLoaderV1(BaseModelLoader):
     """ModelLoader that can load registered models"""

     def __init__(self, load_config: LoadConfig):
@@ -54,13 +56,17 @@ class NewModelLoader(BaseModelLoader):
     def load_model(self, fd_config: FDConfig) -> nn.Layer:
         architectures = fd_config.model_config.architectures[0]
         logger.info(f"Starting to load model {architectures}")

         if fd_config.load_config.dynamic_load_weight:
             # register rl model
             import fastdeploy.rl  # noqa

             architectures = architectures + "RL"
+            context = paddle.LazyGuard()
+        else:
+            context = contextlib.nullcontext()

+        with context:
             model_cls = ModelRegistry.get_class(architectures)
             model = model_cls(fd_config)

View File

@@ -117,13 +117,12 @@ class DeepSeekV3MoE(nn.Layer):
self.tp_size = fd_config.parallel_config.tensor_parallel_size self.tp_size = fd_config.parallel_config.tensor_parallel_size
weight_key_map = { weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias", "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
} }
self.fused_moe = FusedMoE( self.experts = FusedMoE(
fd_config=fd_config, fd_config=fd_config,
reduce_results=False, reduce_results=False,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size, moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
@@ -137,6 +136,16 @@ class DeepSeekV3MoE(nn.Layer):
weight_key_map=weight_key_map, weight_key_map=weight_key_map,
) )
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.n_routed_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
)
self.num_shared_experts = fd_config.model_config.n_shared_experts self.num_shared_experts = fd_config.model_config.n_shared_experts
shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
@@ -149,13 +158,14 @@ class DeepSeekV3MoE(nn.Layer):
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
""" """ """ """
self.fused_moe.load_state_dict(state_dict) self.gate.load_state_dict(state_dict)
self.experts.load_state_dict(state_dict)
self.shared_experts.load_state_dict(state_dict) self.shared_experts.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor): def forward(self, hidden_states: paddle.Tensor):
""" """ """ """
shared_experts_out = self.shared_experts(hidden_states) shared_experts_out = self.shared_experts(hidden_states)
moe_out = self.fused_moe(hidden_states) moe_out = self.experts(hidden_states, self.gate)
moe_out = moe_out + shared_experts_out moe_out = moe_out + shared_experts_out
# We do the TP all-reduce after the sum of the experts. # We do the TP all-reduce after the sum of the experts.
if self.tp_size > 1: if self.tp_size > 1:

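The DeepSeekV3MoE hunk moves the router out of FusedMoE: the gate is now an explicit float32 ReplicatedLinear, and the experts receive it at call time as self.experts(hidden_states, self.gate). A minimal stand-alone sketch of that calling convention, assuming toy shapes (ToyExperts and the per-expert loop are illustrative, not the fused kernel FastDeploy uses):

import paddle
import paddle.nn.functional as F
from paddle import nn


class ToyExperts(nn.Layer):
    """Illustrative experts module that takes the router ("gate") at call time."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        # One tiny FFN per expert, just to make the routing visible.
        self.ffns = nn.LayerList([nn.Linear(hidden_size, hidden_size) for _ in range(num_experts)])

    def forward(self, x: paddle.Tensor, gate: nn.Layer) -> paddle.Tensor:
        # Routing logits are computed in float32, mirroring the float32 gate weight.
        probs = F.softmax(gate(x.cast("float32")), axis=-1)  # [tokens, num_experts]
        out = paddle.zeros_like(x)
        for e, ffn in enumerate(self.ffns):
            out = out + probs[:, e : e + 1].cast(x.dtype) * ffn(x)
        return out


hidden_size, num_experts = 16, 4
gate = nn.Linear(hidden_size, num_experts, bias_attr=False)  # stands in for ReplicatedLinear
experts = ToyExperts(hidden_size, num_experts)
x = paddle.randn([2, hidden_size])
out = experts(x, gate)  # mirrors `moe_out = self.experts(hidden_states, self.gate)`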
View File

@@ -37,6 +37,7 @@ from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import ( from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
@@ -147,7 +148,7 @@ class Ernie4_5_MoE(nn.Layer):
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
} }
self.fused_moe = FusedMoE( self.experts = FusedMoE(
fd_config=fd_config, fd_config=fd_config,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size, moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.moe_num_experts, num_experts=fd_config.model_config.moe_num_experts,
@@ -156,6 +157,16 @@ class Ernie4_5_MoE(nn.Layer):
weight_key_map=weight_key_map, weight_key_map=weight_key_map,
) )
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.moe_num_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
)
self.num_shared_experts = fd_config.model_config.moe_num_shared_experts self.num_shared_experts = fd_config.model_config.moe_num_shared_experts
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
shared_experts_hidden_dim = self.num_shared_experts * fd_config.model_config.moe_intermediate_size shared_experts_hidden_dim = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
@@ -166,12 +177,13 @@ class Ernie4_5_MoE(nn.Layer):
) )
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self.fused_moe.load_state_dict(state_dict) self.gate.load_state_dict(state_dict)
self.experts.load_state_dict(state_dict)
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
self.shared_experts.load_state_dict(state_dict) self.shared_experts.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor): def forward(self, hidden_states: paddle.Tensor):
out = self.fused_moe(hidden_states) out = self.experts(hidden_states, self.gate)
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
s_x = self.shared_experts(hidden_states) s_x = self.shared_experts(hidden_states)
out = out + s_x out = out + s_x
@@ -435,7 +447,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.num_hidden_layers,
): ):
self.ernie.layers[i].mlp.fused_moe(fake_hidden_states) self.ernie.layers[i].mlp.experts(fake_hidden_states)
def forward( def forward(
self, self,

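Ernie4_5_MoE follows the same gate/experts split and adds the shared-experts output to the routed output before the tensor-parallel all-reduce. A small sketch of that combination order, assuming toy modules (ToySharedExperts and moe_forward are illustrative):

import paddle
import paddle.nn.functional as F
from paddle import nn


class ToySharedExperts(nn.Layer):
    """Illustrative shared-experts branch that is applied to every token."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up = nn.Linear(hidden_size, intermediate_size)
        self.down = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        return self.down(F.silu(self.up(x)))


def moe_forward(x, experts, gate, shared_experts=None):
    # Routed output first; the gate module is handed to the experts explicitly.
    out = experts(x, gate)
    # The shared branch is summed *before* any tensor-parallel all-reduce.
    if shared_experts is not None:
        out = out + shared_experts(x)
    return out


shared = ToySharedExperts(16, 32)
dense = nn.Linear(16, 16)
y = moe_forward(paddle.randn([2, 16]), experts=lambda x, g: dense(x), gate=None, shared_experts=shared)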
View File

@@ -33,6 +33,7 @@ from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization, support_graph_optimization,
) )
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import ReplicatedLinear
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.layers.normalization import RMSNorm
@@ -73,6 +74,93 @@ class VLMoEMeta:
fake_hidden_states: Optional[paddle.Tensor] = None fake_hidden_states: Optional[paddle.Tensor] = None
class Ernie4_5_VLMoeBlock(nn.Layer):
def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str, moe_tag: str, expert_id_offset: int) -> None:
super().__init__()
moe_quant_type = ""
if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
if moe_quant_type == "tensor_wise_fp8" or (
moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized
):
weight_key_map = {
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
"up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
"up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
"down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
}
else:
# wint4/wint8/bfloat16
weight_key_map = {
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
moe_intermediate_size = (
fd_config.model_config.moe_intermediate_size[0]
if moe_tag == "Text"
else fd_config.model_config.moe_intermediate_size[1]
)
num_experts = (
fd_config.model_config.moe_num_experts[0]
if moe_tag == "Text"
else fd_config.model_config.moe_num_experts[1]
)
self.experts = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=moe_intermediate_size,
num_experts=num_experts,
expert_id_offset=expert_id_offset,
top_k=fd_config.model_config.moe_k,
layer_idx=layer_id,
moe_tag=moe_tag,
weight_key_map=weight_key_map,
)
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=num_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
weight_key="weight" if moe_tag == "Text" else "weight_1",
)
if moe_tag == "Text":
self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_text
elif moe_tag == "Image":
self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_image
def forward(self, hidden_states: paddle.Tensor):
out = self.experts(hidden_states, self.gate)
return out
def extract_gate_correction_bias_text(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
return gate_correction_bias_tensor[0].unsqueeze(0)
def extract_gate_correction_bias_image(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
return gate_correction_bias_tensor[1].unsqueeze(0)
def load_state_dict(self, state_dict):
self.experts.load_state_dict(state_dict)
self.gate.load_state_dict(state_dict)
class Ernie4_5_VLMoE(nn.Layer): class Ernie4_5_VLMoE(nn.Layer):
def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None: def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
super().__init__() super().__init__()
@@ -99,43 +187,10 @@ class Ernie4_5_VLMoE(nn.Layer):
assert text_moe_layer_start_index <= text_moe_layer_end_index assert text_moe_layer_start_index <= text_moe_layer_end_index
moe_quant_type = ""
if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index: if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
if moe_quant_type == "tensor_wise_fp8" or ( self.text_fused_moe = Ernie4_5_VLMoeBlock(
moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized fd_config=fd_config, layer_id=layer_id, prefix=f"{prefix}", moe_tag="Text", expert_id_offset=0
):
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
"up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
"up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
"down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
}
else:
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.text_fused_moe = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size[0],
num_experts=fd_config.model_config.moe_num_experts[0],
expert_id_offset=0,
top_k=fd_config.model_config.moe_k,
layer_idx=layer_id,
moe_tag="Text",
weight_key_map=weight_key_map,
) )
self.text_fused_moe.extract_gate_correction_bias = self.extract_gate_correction_bias_text
else: else:
self.text_fused_moe = Ernie4_5_VLMLP( self.text_fused_moe = Ernie4_5_VLMLP(
fd_config=fd_config, fd_config=fd_config,
@@ -146,38 +201,13 @@ class Ernie4_5_VLMoE(nn.Layer):
assert image_moe_layer_start_index <= image_moe_layer_end_index assert image_moe_layer_start_index <= image_moe_layer_end_index
if layer_id >= image_moe_layer_start_index and layer_id <= image_moe_layer_end_index: if layer_id >= image_moe_layer_start_index and layer_id <= image_moe_layer_end_index:
if moe_quant_type == "tensor_wise_fp8" or ( self.image_fused_moe = Ernie4_5_VLMoeBlock(
moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized
):
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight_1",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
"up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
"up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
"down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
}
else:
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight_1",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.image_fused_moe = FusedMoE(
fd_config=fd_config, fd_config=fd_config,
reduce_results=False, layer_id=layer_id,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size[1], prefix=f"{prefix}",
num_experts=fd_config.model_config.moe_num_experts[1],
expert_id_offset=fd_config.model_config.moe_num_experts[0],
top_k=fd_config.model_config.moe_k,
layer_idx=layer_id,
moe_tag="Image", moe_tag="Image",
weight_key_map=weight_key_map, expert_id_offset=fd_config.model_config.moe_num_experts[0],
) )
self.image_fused_moe.extract_gate_correction_bias = self.extract_gate_correction_bias_image
else: else:
self.image_fused_moe = Ernie4_5_VLMLP( self.image_fused_moe = Ernie4_5_VLMLP(
fd_config=fd_config, fd_config=fd_config,
@@ -195,25 +225,11 @@ class Ernie4_5_VLMoE(nn.Layer):
reduce_results=False, reduce_results=False,
) )
def extract_gate_correction_bias_text(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
return gate_correction_bias_tensor[0].unsqueeze(0)
def extract_gate_correction_bias_image(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
return gate_correction_bias_tensor[1].unsqueeze(0)
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self.text_fused_moe.load_state_dict(state_dict) self.text_fused_moe.load_state_dict(state_dict)
self.image_fused_moe.load_state_dict(state_dict) self.image_fused_moe.load_state_dict(state_dict)
if self.text_fused_moe.moe_use_gate_correction_bias: if self.text_fused_moe.experts.moe_use_gate_correction_bias:
state_dict.pop(self.text_fused_moe.gate_correction_bias_key) state_dict.pop(self.text_fused_moe.experts.gate_correction_bias_key)
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
self.shared_experts.load_state_dict(state_dict) self.shared_experts.load_state_dict(state_dict)

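The new Ernie4_5_VLMoeBlock keeps one stacked e_score_correction_bias tensor in the checkpoint and slices row 0 for the Text experts and row 1 for the Image experts (matching the gate weight_key of "weight" versus "weight_1"). A minimal sketch of that selection, assuming a fake state dict (the key name and expert count are illustrative):

import paddle


def extract_gate_correction_bias(state_dict, key, moe_tag):
    """Slice the per-modality row out of a stacked [2, num_experts] bias tensor."""
    bias = paddle.to_tensor(state_dict[key]).astype("float32")
    row = 0 if moe_tag == "Text" else 1  # row 0 -> Text experts, row 1 -> Image experts
    return bias[row].unsqueeze(0)  # shape [1, num_experts]


fake_state_dict = {"mlp.moe_statics.e_score_correction_bias": paddle.rand([2, 64])}
text_bias = extract_gate_correction_bias(fake_state_dict, "mlp.moe_statics.e_score_correction_bias", "Text")
image_bias = extract_gate_correction_bias(fake_state_dict, "mlp.moe_statics.e_score_correction_bias", "Image")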
View File

@@ -32,6 +32,7 @@ from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import ( from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
@@ -41,6 +42,47 @@ from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.qwen3 import Qwen3Attention from fastdeploy.model_executor.models.qwen3 import Qwen3Attention
class Qwen3MoeBlock(nn.Layer):
def __init__(
self,
fd_config: FDConfig,
layer_id: int,
prefix: str = "",
) -> None:
super().__init__()
weight_key_map = {
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.experts = FusedMoE(
fd_config,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.num_experts,
top_k=fd_config.model_config.num_experts_per_tok,
layer_idx=layer_id,
weight_key_map=weight_key_map,
)
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.num_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
)
def forward(self, x):
out = self.experts(x, self.gate)
return out
def load_state_dict(self, state_dict):
""" """
self.gate.load_state_dict(state_dict)
self.experts.load_state_dict(state_dict)
class Qwen3MLP(nn.Layer): class Qwen3MLP(nn.Layer):
""" """ """ """
@@ -104,22 +146,13 @@ class Qwen3DecoderLayer(nn.Layer):
layer_id=layer_id, layer_id=layer_id,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
mlp_only_layers = (
weight_key_map = { [] if not hasattr(fd_config.model_config, "mlp_only_layers") else fd_config.model_config.mlp_only_layers
"gate_weight_key": f"{prefix}.mlp.gate.weight",
"up_gate_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.down_proj.weight",
}
if fd_config.model_config.num_experts is not None and layer_id >= fd_config.model_config.moe_layer_start_index:
self.mlp = FusedMoE(
fd_config,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.num_experts,
top_k=fd_config.model_config.num_experts_per_tok,
layer_idx=layer_id,
weight_key_map=weight_key_map,
) )
if (layer_id not in mlp_only_layers) and (
fd_config.model_config.num_experts > 0 and (layer_id + 1) % fd_config.model_config.decoder_sparse_step == 0
):
self.mlp = Qwen3MoeBlock(fd_config, layer_id, prefix=f"{prefix}.mlp")
else: else:
self.mlp = Qwen3MLP( self.mlp = Qwen3MLP(
fd_config, fd_config,
@@ -279,6 +312,74 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
""" """ """ """
return "Qwen3MoeForCausalLM" return "Qwen3MoeForCausalLM"
def get_expert_mapping(
self,
) -> list[tuple[str, str, int, str]]:
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
param_gate_up_proj_name="experts.up_gate_proj_",
param_down_proj_name="experts.down_proj_",
num_experts=self.fd_config.model_config.num_experts,
)
@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
"""
Load model parameters from a given weights_iterator object.
Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.models.utils import default_weight_loader
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("up_gate_proj", "gate_proj", "gate"),
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
]
expert_params_mapping = self.get_expert_mapping()
params_dict = dict(self.named_parameters())
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
continue
if "mlp.experts" in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
break
else:
if loaded_weight_name not in params_dict:
continue
param = params_dict[loaded_weight_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
@paddle.no_grad() @paddle.no_grad()
def set_state_dict(self, state_dict): def set_state_dict(self, state_dict):
""" """

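Qwen3DecoderLayer now decides per layer whether to build a Qwen3MoeBlock from mlp_only_layers and decoder_sparse_step instead of a single moe_layer_start_index. A small sketch of that predicate, assuming illustrative config values:

def use_moe(layer_id, num_experts, decoder_sparse_step, mlp_only_layers=()):
    """Return True when this decoder layer should host a MoE block."""
    return (
        layer_id not in mlp_only_layers
        and num_experts > 0
        and (layer_id + 1) % decoder_sparse_step == 0
    )


# With 48 layers, 128 experts and a sparse step of 1, every layer is MoE:
assert all(use_moe(i, 128, 1) for i in range(48))
# A sparse step of 2 keeps the dense MLP on even layers (0-indexed):
assert use_moe(1, 128, 2) and not use_moe(0, 128, 2)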
View File

@@ -72,7 +72,11 @@ def default_weight_loader(fd_config: FDConfig) -> None:
loaded_weight = loaded_weight[..., shard_offset:shard_size] loaded_weight = loaded_weight[..., shard_offset:shard_size]
else: else:
loaded_weight = loaded_weight[shard_offset:shard_size, ...] loaded_weight = loaded_weight[shard_offset:shard_size, ...]
loaded_weight = get_tensor(loaded_weight) loaded_weight = get_tensor(loaded_weight)
# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
if param.dtype != loaded_weight.dtype:
loaded_weight = loaded_weight.cast(param.dtype)
assert param.shape == loaded_weight.shape, ( assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"

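The default_weight_loader change casts the loaded tensor to the parameter's dtype before the shape check, since mlp.gate.weight is kept in float32 while checkpoints typically store bfloat16. A minimal sketch of that guard, assuming illustrative tensors:

import paddle

param = paddle.zeros([4, 8], dtype="float32")          # e.g. the float32 gate weight
loaded_weight = paddle.randn([4, 8]).cast("bfloat16")  # checkpoint tensor in bfloat16

# Cast first so the precision-sensitive gate weight is compared and stored in float32.
if param.dtype != loaded_weight.dtype:
    loaded_weight = loaded_weight.cast(param.dtype)

assert param.shape == loaded_weight.shape, (
    f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape})"
)
paddle.assign(loaded_weight, output=param)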
View File

@@ -156,12 +156,12 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
# Helper function to add layer mappings # Helper function to add layer mappings
def _add_layer_mappings(layer_idx: int): def _add_layer_mappings(layer_idx: int):
# MoE specific mappings # MoE specific mappings
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = ( self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = (
f"{base_name}.{layer_idx}.mlp.gate.weight" f"{base_name}.{layer_idx}.mlp.gate.weight"
) )
if self.fd_config.model_config.moe_use_aux_free: if self.fd_config.model_config.moe_use_aux_free:
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = ( self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.experts.gate_correction_bias"] = (
f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias" f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
) )
@@ -169,7 +169,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
for expert_idx in range(self.fd_config.model_config.moe_num_experts): for expert_idx in range(self.fd_config.model_config.moe_num_experts):
for ph in place_holders: for ph in place_holders:
# up_gate_proj (up_gate_proj) # up_gate_proj (up_gate_proj)
up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.up_gate_proj_weight" up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"
if up_gate_proj_key not in self.infer_to_train_mapping: if up_gate_proj_key not in self.infer_to_train_mapping:
self.infer_to_train_mapping[up_gate_proj_key] = [] self.infer_to_train_mapping[up_gate_proj_key] = []
self.infer_to_train_mapping[up_gate_proj_key].append( self.infer_to_train_mapping[up_gate_proj_key].append(
@@ -177,7 +177,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
) )
# down_proj (down_proj) # down_proj (down_proj)
down_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.down_proj_weight" down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight"
if down_proj_key not in self.infer_to_train_mapping: if down_proj_key not in self.infer_to_train_mapping:
self.infer_to_train_mapping[down_proj_key] = [] self.infer_to_train_mapping[down_proj_key] = []
self.infer_to_train_mapping[down_proj_key].append( self.infer_to_train_mapping[down_proj_key].append(
@@ -230,13 +230,13 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
def _add_expert_mappings(layer_idx: int, moe_tag: str, expert_start: int): def _add_expert_mappings(layer_idx: int, moe_tag: str, expert_start: int):
# MoE specific mappings # MoE specific mappings
gate_suffix = "" if moe_tag == "text" else "_1" gate_suffix = "" if moe_tag == "text" else "_1"
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_weight"] = ( self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate.weight"] = (
f"{base_name}.{layer_idx}.mlp.gate.weight{gate_suffix}" f"{base_name}.{layer_idx}.mlp.gate.weight{gate_suffix}"
) )
if self.fd_config.model_config.moe_use_aux_free: if self.fd_config.model_config.moe_use_aux_free:
self.infer_to_train_mapping[ self.infer_to_train_mapping[
f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_correction_bias" f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.experts.gate_correction_bias"
] = f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias" ] = f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
# Initialize defaultdict for expert weights # Initialize defaultdict for expert weights
@@ -255,12 +255,12 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
expert_num_per_rank, expert_num_per_rank,
): ):
for ph in place_holders: for ph in place_holders:
expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"].append( expert_mappings[
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}" f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.experts.up_gate_proj_weight"
) ].append(f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}")
expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.down_proj_weight"].append( expert_mappings[
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}" f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.experts.down_proj_weight"
) ].append(f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}")
self.infer_to_train_mapping.update(expert_mappings) self.infer_to_train_mapping.update(expert_mappings)
moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index
@@ -375,12 +375,12 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel):
# Helper function to add layer mappings # Helper function to add layer mappings
def _add_layer_mappings(layer_idx: int): def _add_layer_mappings(layer_idx: int):
# MoE specific mappings # MoE specific mappings
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate_weight"] = ( self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = (
f"{base_name}.{layer_idx}.mlp.gate.weight" f"{base_name}.{layer_idx}.mlp.gate.weight"
) )
if self.fd_config.moe_config.moe_use_aux_free: if self.fd_config.moe_config.moe_use_aux_free:
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = ( self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.experts.gate_correction_bias"] = (
f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias" f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
) )
@@ -388,7 +388,7 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel):
for expert_idx in range(self.fd_config.moe_config.num_experts): for expert_idx in range(self.fd_config.moe_config.num_experts):
for ph in place_holders: for ph in place_holders:
# up_gate_proj (up_gate_proj) # up_gate_proj (up_gate_proj)
up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.up_gate_proj_weight" up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"
if up_gate_proj_key not in self.infer_to_train_mapping: if up_gate_proj_key not in self.infer_to_train_mapping:
self.infer_to_train_mapping[up_gate_proj_key] = [] self.infer_to_train_mapping[up_gate_proj_key] = []
self.infer_to_train_mapping[up_gate_proj_key].append( self.infer_to_train_mapping[up_gate_proj_key].append(
@@ -396,7 +396,7 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel):
) )
# down_proj (down_proj) # down_proj (down_proj)
down_proj_key = f"{base_name}.{layer_idx}.mlp.down_proj_weight" down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight"
if down_proj_key not in self.infer_to_train_mapping: if down_proj_key not in self.infer_to_train_mapping:
self.infer_to_train_mapping[down_proj_key] = [] self.infer_to_train_mapping[down_proj_key] = []
self.infer_to_train_mapping[down_proj_key].append( self.infer_to_train_mapping[down_proj_key].append(

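The RL hunks rename the inference-side keys from mlp.fused_moe.* to mlp.gate.weight and mlp.experts.*. A small sketch of how one layer's infer-to-train mapping could be assembled under the new naming (build_layer_mapping, base_name, and the expert count are illustrative, not the FastDeploy helper):

def build_layer_mapping(base_name, layer_idx, num_experts, use_aux_free=True):
    """Map inference parameter names to training checkpoint names for one MoE layer."""
    mapping = {
        f"{base_name}.{layer_idx}.mlp.gate.weight": f"{base_name}.{layer_idx}.mlp.gate.weight",
    }
    if use_aux_free:
        mapping[f"{base_name}.{layer_idx}.mlp.experts.gate_correction_bias"] = (
            f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
        )
    # Expert weights are gathered into one stacked key per projection.
    mapping[f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"] = [
        f"{base_name}.{layer_idx}.mlp.experts.{e}.up_gate_proj.weight" for e in range(num_experts)
    ]
    mapping[f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight"] = [
        f"{base_name}.{layer_idx}.mlp.experts.{e}.down_proj.weight" for e in range(num_experts)
    ]
    return mapping


layer_map = build_layer_mapping("ernie.layers", 3, num_experts=64)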
File diff suppressed because it is too large

View File

@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import difflib
from paddleformers.trl.llm_utils import init_dist_env from paddleformers.trl.llm_utils import init_dist_env
@@ -50,23 +49,35 @@ for k, v in actor_eval_model.get_name_mappings_to_training().items():
content += f"{k}:{v}\n" content += f"{k}:{v}\n"
def compare_strings(a: str, b: str) -> bool: def compare_strings_line_by_line(a: str, b: str) -> bool:
if a == b: """
print("✅ The two strings are identical") Compare two multiline strings line by line.
return True
print("❌ Strings differ; differences shown below (context diff):") Returns:
diff = difflib.ndiff(a.splitlines(), b.splitlines()) True if all lines match exactly in order and content.
for line in diff: False if any line differs or the number of lines is not equal.
if line.startswith("- ") or line.startswith("+ "): """
print(line) a_lines = a.splitlines()
b_lines = b.splitlines()
if len(a_lines) != len(b_lines):
print(f"❌ Mismatch in number of lines: expected {len(a_lines)}, but got {len(b_lines)}.")
return False return False
for i, (line_a, line_b) in enumerate(zip(a_lines, b_lines)):
if line_a != line_b:
print(f"❌ Difference found on line {i + 1}:")
print(f" Expected: {repr(line_a)}")
print(f" Actual : {repr(line_b)}")
return False
print("✅ All lines match exactly.")
return True
with open("baseline.txt", "r", encoding="utf-8") as f: with open("baseline.txt", "r", encoding="utf-8") as f:
baseline = f.read() baseline = f.read()
assert compare_strings(baseline, content), ( assert compare_strings_line_by_line(baseline, content), (
"In the unittest of RL scenario, your modification " "In the unittest of RL scenario, your modification "
"caused inconsistency in the content before and after. Please fix it. " "caused inconsistency in the content before and after. Please fix it. "
"Can request assistance from yuanlehome or gzy19990617 (github id)." "Can request assistance from yuanlehome or gzy19990617 (github id)."