[v1 loader] Qwen offline FP8 (#4036)

* support offline fp8

* update ut

* update ut

* update ut

* fix

* update

* update
bukejiyu
2025-09-15 13:44:11 +08:00
committed by GitHub
parent b1a5b756a3
commit 29ed617f0f
21 changed files with 440 additions and 138 deletions

View File

@@ -57,7 +57,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
{
**extra_weight_attrs,
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"model_format": extra_weight_attrs.get("model_format", ""),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
},
)
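The hunk above replaces the raw "model_format" attribute with a precomputed "weight_need_transpose" boolean. As a rough illustration (the helper name below is ours, not from the diff), a loader only needs this flag to decide whether a torch-layout weight must be flipped into Paddle layout:

import paddle

def maybe_transpose(param, loaded_weight):
    # torch checkpoints store nn.Linear weights as [out_features, in_features];
    # Paddle Linear expects [in_features, out_features], so flip when flagged.
    if getattr(param, "weight_need_transpose", False):
        loaded_weight = loaded_weight.transpose([1, 0])
    return loaded_weight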
@@ -341,10 +341,10 @@ class MergedReplicatedLinear(ReplicatedLinear):
self.output_sizes = output_sizes
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
model_format = getattr(param, "model_format", "")
weight_need_transpose = getattr(param, "weight_need_transpose", False)
loaded_weight = get_tensor(loaded_weight)
if model_format == "torch":
if weight_need_transpose:
loaded_weight = loaded_weight.transpose([1, 0])
assert loaded_shard_id in ["q_a", "kv_a"]
@@ -365,6 +365,12 @@ class MergedReplicatedLinear(ReplicatedLinear):
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
param.copy_(loaded_weight, False)
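Offline FP8 checkpoints commonly serialize the 8-bit weights as raw int8 buffers, which is why the dtype fix-up above reinterprets bits instead of casting values. A minimal sketch of the same rule (function name is ours), assuming a Paddle build that exposes float8_e4m3fn as the surrounding code does:

import paddle

def match_param_dtype(param, loaded_weight):
    if loaded_weight.dtype == param.dtype:
        return loaded_weight
    if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
        # same bytes, new dtype: the int8 buffer already holds FP8 bit patterns
        return loaded_weight.view(param.dtype)
    # otherwise convert numerically (e.g. fp32 scale tensors into bf16 params)
    return loaded_weight.cast(param.dtype)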
@@ -483,15 +489,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
model_format = getattr(param, "model_format", "")
if model_format == "torch":
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
weight_need_transpose = getattr(param, "weight_need_transpose", False)
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
shard_dim = -1 if output_dim else 0
output_size = param.shape[shard_dim]
if loaded_shard_id is None:
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
param.weight_need_transpose = False
# Loaded weight is already fused on disk.
shard_offsets = [
# (shard_id, shard_offset, shard_size)
@@ -506,6 +513,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
else:
# split gate up
assert loaded_shard_id in ["gate", "up"]
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.nranks != 1:
dim = -1 if output_dim else 0
@@ -532,6 +542,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
param.copy_(loaded_weight, False)
def load_state_dict(self, state_dict: dict):
@@ -604,11 +620,11 @@ class QKVParallelLinear(ColumnParallelLinear):
add_bias=add_bias,
)
def _get_shard_size_mapping(self, loaded_shard_id: str):
def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
shard_size_mapping = {
"q": self.num_heads_per_rank * self.head_dim,
"k": self.kv_num_heads_per_rank * self.head_dim,
"v": self.kv_num_heads_per_rank * self.head_dim,
"q": self.num_heads_per_rank * head_dim,
"k": self.kv_num_heads_per_rank * head_dim,
"v": self.kv_num_heads_per_rank * head_dim,
}
return shard_size_mapping.get(loaded_shard_id)
@@ -617,11 +633,12 @@ class QKVParallelLinear(ColumnParallelLinear):
assert output_dim is not None
dim = -1 if output_dim else 0
head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
model_format = getattr(param, "model_format", "")
if model_format == "torch":
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if loaded_shard_id is None:
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
if loaded_shard_id is None:
param.weight_need_transpose = False
# Loaded weight is already fused on disk
shard_offsets = [
# (shard_id, shard_offset, shard_size)
@@ -637,13 +654,16 @@ class QKVParallelLinear(ColumnParallelLinear):
else:
# split q k v
assert loaded_shard_id in ["q", "k", "v"]
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.nranks != 1:
block_size = self._get_shard_size_mapping(loaded_shard_id)
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
shard_offset = shard_id * block_size
shard_size = (shard_id + 1) * block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
shard_size = block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)
loaded_weight = get_tensor(loaded_weight)
@@ -663,10 +683,17 @@ class QKVParallelLinear(ColumnParallelLinear):
param_shard_size = self.kv_num_heads_per_rank * head_dim
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
param.copy_(loaded_weight, False)
def load_weight(self, state_dict: dict):
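The shard arithmetic above now derives head_dim from the parameter itself (so packed or quantized layouts still slice correctly) and treats shard_size as a per-shard width rather than an end offset. A worked example with illustrative sizes:

# Illustrative numbers only: 8 query heads and 2 KV heads per rank, head_dim 128.
num_heads_per_rank, kv_num_heads_per_rank, head_dim = 8, 2, 128
block_size = {
    "q": num_heads_per_rank * head_dim,     # 1024 columns per rank
    "k": kv_num_heads_per_rank * head_dim,  # 256
    "v": kv_num_heads_per_rank * head_dim,  # 256
}
local_rank = 3
shard_offset = local_rank * block_size["q"]  # 3072
shard_size = block_size["q"]                 # 1024
# slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)
# picks columns [3072, 4096) of the fused checkpoint weight as this rank's q shard.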

View File

@@ -91,7 +91,7 @@ class ParallelLMHead(nn.Layer):
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
"weight_need_transpose": self.fd_config.model_config.model_format == "torch",
},
)
if self.nranks > 1:
@@ -110,7 +110,7 @@ class ParallelLMHead(nn.Layer):
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
"weight_need_transpose": self.fd_config.model_config.model_format == "torch",
},
)

View File

@@ -216,18 +216,17 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.up_gate_proj_weight,
{
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"model_format": extra_weight_attrs.get("model_format", ""),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
},
)
set_weight_attrs(
layer.down_proj_weight,
{
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"model_format": extra_weight_attrs.get("model_format", ""),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
},
)

View File

@@ -1024,8 +1024,8 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
]
self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
dtype=layer.weight_dtype,
@@ -1037,7 +1037,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
@@ -1097,7 +1097,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs)
set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs)

View File

@@ -57,7 +57,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
]
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
dtype=layer.weight_dtype,
@@ -69,6 +70,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
@@ -127,6 +129,25 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(
getattr(layer, up_gate_proj_weight_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, up_gate_proj_scale_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, down_proj_weight_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, down_proj_scale_name),
extra_weight_attrs,
)
def process_weights_after_loading(self, layer):
""" """
@@ -169,6 +190,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
)
weight[expert_id].copy_(weight_quant, False)
getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
# create weight
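For reference, block-wise FP8 quantization of a bf16 expert weight amounts to choosing one scale per 128x128 tile so the tile fits the float8_e4m3 range, then casting. The sketch below is ours (not the kernel invoked in the diff) and assumes the Paddle build supports casting to float8_e4m3fn, as the surrounding code does:

import paddle

def blockwise_fp8_quant(w, block_size=(128, 128)):
    # Per-tile max-abs scaling into the float8_e4m3 range (max finite value 448).
    rows, cols = w.shape
    n_row = (rows + block_size[0] - 1) // block_size[0]
    n_col = (cols + block_size[1] - 1) // block_size[1]
    scale = paddle.zeros([n_row, n_col], dtype="float32")
    scaled = paddle.zeros([rows, cols], dtype="float32")
    for bi in range(n_row):
        for bj in range(n_col):
            r0, c0 = bi * block_size[0], bj * block_size[1]
            blk = w[r0:r0 + block_size[0], c0:c0 + block_size[1]].astype("float32")
            s = max(float(blk.abs().max()), 1e-12) / 448.0
            scale[bi, bj] = s
            scaled[r0:r0 + block_size[0], c0:c0 + block_size[1]] = blk / s
    # the scale tensor plays the role of weight_scale_inv: original ~= fp8 * scale
    return scaled.cast(paddle.float8_e4m3fn), scale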

View File

@@ -72,7 +72,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer.moe_intermediate_size,
layer.hidden_size,
]
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
@@ -84,6 +85,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
@@ -136,6 +139,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# Support for the cache feature will be added in the future
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
@@ -723,7 +727,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
]
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
dtype=layer.weight_dtype,
@@ -735,6 +740,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
@@ -794,6 +800,26 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
),
)
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(
getattr(layer, up_gate_proj_weight_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, up_gate_proj_scale_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, down_proj_weight_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, down_proj_scale_name),
extra_weight_attrs,
)
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:

View File

@@ -206,20 +206,19 @@ class FusedMoE(nn.Layer):
if shard_id is None:
# 1.gate up fused in disk
model_format = getattr(param, "model_format", "")
is_torch_model = model_format == "torch"
weight_need_transpose = getattr(param, "weight_need_transpose", False)
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
per_rank = output_size // 2
start = self.tp_rank * per_rank
loaded_weight_shard_gate = slice_fn(
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
)
self._load_gate_up_weight(
param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
)
start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
loaded_weight_shard_up = slice_fn(
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
)
self._load_gate_up_weight(
param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
@@ -236,10 +235,9 @@ class FusedMoE(nn.Layer):
)
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
model_format = getattr(param, "model_format", "")
is_torch_model = model_format == "torch"
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if self.tp_size > 1 and not is_sharded:
tp_shard_dim = is_torch_model ^ shard_dim
tp_shard_dim = weight_need_transpose ^ shard_dim
weight_dim = -1 if tp_shard_dim else 0
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
size = loaded_weight.shape[weight_dim]
@@ -275,13 +273,17 @@ class FusedMoE(nn.Layer):
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
if expert_param.dtype != loaded_weight.dtype:
if loaded_weight.dtype == paddle.int8 and expert_param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(expert_param.dtype)
else:
loaded_weight = loaded_weight.cast(expert_param.dtype)
expert_param.copy_(loaded_weight, False)
def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
model_format = getattr(param, "model_format", "")
is_torch_model = model_format == "torch"
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if self.tp_size > 1 and shard_dim is not None:
tp_shard_dim = is_torch_model ^ shard_dim
tp_shard_dim = weight_need_transpose ^ shard_dim
dim = -1 if tp_shard_dim else 0
if isinstance(loaded_weight, paddle.Tensor):
size = loaded_weight.shape[dim]
@@ -302,6 +304,11 @@ class FusedMoE(nn.Layer):
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
if expert_param.dtype != loaded_weight.dtype:
if loaded_weight.dtype == paddle.int8 and expert_param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(expert_param.dtype)
else:
loaded_weight = loaded_weight.cast(expert_param.dtype)
expert_param.copy_(loaded_weight, False)
def _load_expert_weight(
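The weight_need_transpose ^ shard_dim expression above is a compact way to flip the tensor-parallel shard axis whenever the on-disk expert weight is transposed relative to the Paddle parameter. A tiny truth-table sketch:

# weight_need_transpose XOR shard_dim: a transposed checkpoint flips the shard axis.
for weight_need_transpose in (False, True):
    for shard_dim in (0, 1):
        tp_shard_dim = weight_need_transpose ^ shard_dim
        print(f"transpose={weight_need_transpose} shard_dim={shard_dim} -> {tp_shard_dim}")
# transpose=False keeps shard_dim unchanged; transpose=True maps 0 -> 1 and 1 -> 0.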

View File

@@ -34,6 +34,72 @@ QUANTIZATION_METHODS: List[str] = [
]
def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
# 1.model_config.is_quantized
# TODO(bukejiyu): model_config.is_quantized is v0-only and needs to be removed in the future
if model_config.model_format == "torch":
quantization_config = model_config.quantization_config
if quantization_config is not None:
model_config.is_quantized = True
else:
quantization_config = model_config.quantization_config
if not model_config.is_quantized:
if quantization_config is not None:
if "is_quantized" in quantization_config:
model_config.is_quantized = quantization_config["is_quantized"]
elif "kv_cache_quant_type" not in quantization_config:
model_config.is_quantized = True
if quantization_config is not None and quantization_config.get("quantization", None) is None:
raise ValueError(
"quantization_config should have a key named 'quantization' for specify quant config."
)
quant_config_name = None
if quantization_config is not None:
quant_config_name = _get_offline_quant_config_name(
quantization_config, model_config.model_format == "torch", is_v1_loader
)
elif args.quantization is not None:
quantization_config = {}
try:
quantization_config.update(args.quantization)
quant_config_name = quantization_config["quantization"]
except:
quant_config_name = args.quantization["quantization"]
quantization_config["quantization"] = quant_config_name
# Special handling for Ernie models
if quant_config_name == "wint4" and is_ernie:
quantization_config["dense_quant_type"] = "wint8"
quantization_config["moe_quant_type"] = "wint4"
quantization_config["quantization"] = "mix_quant"
quant_config_name = "mix_quant"
else:
quant_config_name = None
if quant_config_name is None:
quant_config = None
else:
if not quantization_config.get("is_quantized"):
quantization_config["is_quantized"] = model_config.is_quantized
quant_cls = get_quantization_config(quant_config_name)
quant_config = quant_cls.from_config(quantization_config)
return quant_config
def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_loader):
if is_torch_weight:
# only support block_wise_fp8 now
quant_method = quantization_config.get("quant_method")
has_block_size = "weight_block_size" in quantization_config
if quant_method == "fp8" and has_block_size:
quant_config_name = "block_wise_fp8"
else:
raise ValueError("Torch weight offline quantization only supports block-wise FP8.")
else:
quant_config_name = quantization_config["quantization"]
return quant_config_name
def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
"""
Get the quantization config class by the quantization name.
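For a torch-format (safetensors) checkpoint, the quantization_config block in config.json is the HuggingFace-style one, and only the block-wise FP8 form is recognized so far. An example config (field values are illustrative) that resolves to "block_wise_fp8":

quantization_config = {
    "quant_method": "fp8",
    "weight_block_size": [128, 128],
}
name = _get_offline_quant_config_name(quantization_config, is_torch_weight=True, is_v1_loader=True)
# name == "block_wise_fp8"; any other torch-side quant_method raises ValueError.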

View File

@@ -53,7 +53,7 @@ class BlockWiseFP8Config(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "BlockWiseFP8Config":
weight_block_size = config.get("weight_block_size", [128, 128])
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(weight_block_size, is_checkpoint_bf16)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -89,13 +89,15 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer, **extra_weight_attrs):
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
quant_attrs = extra_weight_attrs
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
quant_attrs = {
@@ -120,14 +122,28 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
layer.weight_scale_inv = layer.create_parameter(
shape=[
(layer.output_size + self.quant_config.weight_block_size[0] - 1)
(layer.weight_shape[0] + self.quant_config.weight_block_size[0] - 1)
// self.quant_config.weight_block_size[0],
(layer.input_size + self.quant_config.weight_block_size[1] - 1)
(layer.weight_shape[1] + self.quant_config.weight_block_size[1] - 1)
// self.quant_config.weight_block_size[1],
],
dtype="float32",
is_bias=False,
)
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.weight,
extra_weight_attrs,
)
set_weight_attrs(
layer.weight_scale_inv,
{
**extra_weight_attrs,
"is_scale": True,
},
)
def process_weights_after_loading(self, layer) -> None:
if not self.quant_config.is_checkpoint_bf16:
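The scale tensor above holds one float32 entry per weight block, i.e. a ceiling division of each weight dimension by the block size. A worked example, assuming weight_shape is [input_size, output_size] as for a Paddle Linear:

weight_shape = [4096, 12288]      # illustrative [input_size, output_size]
weight_block_size = [128, 128]
scale_rows = (weight_shape[0] + weight_block_size[0] - 1) // weight_block_size[0]  # 32
scale_cols = (weight_shape[1] + weight_block_size[1] - 1) // weight_block_size[1]  # 96
# layer.weight_scale_inv is created with shape [32, 96] for this layer.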

View File

@@ -37,7 +37,7 @@ class MixQuantConfig(QuantConfigBase):
is_channel_wise: bool = False,
has_zero_point: bool = False,
is_permuted: bool = True,
is_checkpoint_bf16: bool = False,
is_quantized: bool = False,
hadamard_block_size: int = 128,
) -> None:
super().__init__()
@@ -54,7 +54,8 @@ class MixQuantConfig(QuantConfigBase):
self.quant_min_bound = 0
self.quant_round_type = 0
self.is_permuted = is_permuted
self.is_checkpoint_bf16 = is_checkpoint_bf16
self.is_checkpoint_bf16 = not is_quantized
self.is_quantized = is_quantized
self.hadamard_block_size = hadamard_block_size
def name(self) -> str:
@@ -70,7 +71,7 @@ class MixQuantConfig(QuantConfigBase):
config.get("is_channel_wise", False),
config.get("has_zero_point", False),
config.get("is_permuted", True),
config.get("is_checkpoint_bf16", False),
config.get("is_quantized", False),
config.get("hadamard_block_size", 128),
)
@@ -82,7 +83,7 @@ class MixQuantConfig(QuantConfigBase):
.from_config(
{
"is_permuted": self.is_permuted,
"is_checkpoint_bf16": self.is_checkpoint_bf16,
"is_quantized": self.is_quantized,
"hadamard_block_size": self.hadamard_block_size,
}
)
@@ -94,7 +95,7 @@ class MixQuantConfig(QuantConfigBase):
.from_config(
{
"is_permuted": self.is_permuted,
"is_checkpoint_bf16": self.is_checkpoint_bf16,
"is_quantized": self.is_quantized,
"hadamard_block_size": self.hadamard_block_size,
}
)
@@ -112,6 +113,6 @@ class MixQuantConfig(QuantConfigBase):
else:
return (
get_quantization_config(self.dense_quant_type)
.from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
.from_config({"is_quantized": self.is_quantized})
.get_quant_method(layer)
)

View File

@@ -65,7 +65,7 @@ class WeightOnlyConfig(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "WeightOnlyConfig":
algo = config["algo"]
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(algo, is_checkpoint_bf16)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -162,7 +162,7 @@ class WINT8Config(WeightOnlyConfig):
@classmethod
def from_config(cls, config: dict) -> "WINT8Config":
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(is_checkpoint_bf16)
def name(self) -> str:
@@ -182,7 +182,7 @@ class WINT4Config(WeightOnlyConfig):
@classmethod
def from_config(cls, config: dict) -> "WINT4Config":
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(is_checkpoint_bf16)
def name(self) -> str:
@@ -202,13 +202,15 @@ class WeightOnlyLinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer, **extra_weight_attrs):
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
quant_attrs = extra_weight_attrs
if (
isinstance(layer, MergedColumnParallelLinear)
@@ -256,6 +258,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
{
"weight_loader": weight_loader,
"output_dim": output_dim,
"weight_need_transpose": not extra_weight_attrs.get("model_format") == "torch",
},
)

View File

@@ -60,7 +60,7 @@ class WFP8AFP8Config(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "WFP8AFP8Config":
""" """
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(is_checkpoint_bf16=is_checkpoint_bf16)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -92,13 +92,14 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
(weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1
)
scale_shape = scale_shape[::-1]
if self.quant_config.is_checkpoint_bf16:
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.weight = layer.create_parameter(
shape=weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
quant_attrs = extra_weight_attrs
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
quant_attrs = {

View File

@@ -98,7 +98,7 @@ def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
)
enable_cache = True
weight_cache_context = switch_config_context(fd_config.quant_config, "is_checkpoint_bf16", False)
weight_cache_context = switch_config_context(fd_config.quant_config, "is_quantized", True)
return enable_cache, weight_cache_dir, weight_cache_context
@@ -150,7 +150,8 @@ def save_model(model_arg_name="model", config_arg_name="fd_config"):
)
_save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
else:
logger.info("Weights are already cached, skip saving")
reason = "weights already cached" if envs.FD_ENABLE_MODEL_LOAD_CACHE else "cache disabled"
logger.info(f"Skip saving ,{reason}")
return result
return wrapper
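switch_config_context is a context manager that temporarily overrides a config attribute and restores it on exit; here the quant config reports is_quantized=True while the cached, already-quantized weights are read back. A usage sketch (the loader call is hypothetical):

with switch_config_context(fd_config.quant_config, "is_quantized", True):
    load_cached_state_dict(model)  # hypothetical cache-loading call
# on exit, is_quantized is restored to its original value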

View File

@@ -527,6 +527,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
rename_offline_ckpt_suffix_to_fd_suffix,
)
general_params_mapping = [
@@ -564,15 +565,20 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
param_down_proj_name="experts.down_proj_",
num_experts_start_offset=num_experts_start_offset,
)
all_param_mapping = general_params_mapping + expert_params_mapping
all_param_mapping = [
(param, weight, exp, shard, False) for param, weight, exp, shard in general_params_mapping
] + [(param, weight, exp, shard, True) for param, weight, exp, shard in expert_params_mapping]
checkpoint_to_fd_key_fn = rename_offline_ckpt_suffix_to_fd_suffix(
fd_config=self.fd_config, ckpt_weight_suffix="quant_weight", ckpt_scale_suffix="weight_scale"
)
params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
loaded_weight_name = loaded_weight_name.replace("model", "ernie")
for param_name, weight_name, exp_id, shard_id in all_param_mapping:
for param_name, weight_name, exp_id, shard_id, is_moe in all_param_mapping:
loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe)
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
@@ -583,6 +589,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
else:
expert_id = None
shard_id = None
loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe=False)
model_param_name = loaded_weight_name
if model_param_name not in params_dict.keys():
continue

View File

@@ -193,16 +193,16 @@ class VisionFlashAttention2(nn.Layer):
self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
self.proj = nn.Linear(dim, dim)
set_weight_attrs(self.qkv.weight, {"model_format": model_format})
set_weight_attrs(self.proj.weight, {"model_format": model_format})
set_weight_attrs(self.qkv.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.proj.weight, {"weight_need_transpose": model_format == "torch"})
self.head_dim = dim // num_heads # must added
self.num_heads = num_heads
self.hidden_size = dim
self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
model_format = getattr(param, "model_format", "")
if model_format == "torch":
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = loaded_weight.transpose([1, 0])
load_bias = getattr(param, "load_bias", None)
if load_bias:
@@ -358,8 +358,8 @@ class VisionMlp(nn.Layer):
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
set_weight_attrs(self.fc1.weight, {"model_format": model_format})
set_weight_attrs(self.fc2.weight, {"model_format": model_format})
set_weight_attrs(self.fc1.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.fc2.weight, {"weight_need_transpose": model_format == "torch"})
self.act = ACT2FN[hidden_act]
@@ -528,8 +528,10 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
in_channels=config.vision_config.in_channels,
embed_dim=config.vision_config.embed_dim,
)
model_format = getattr(config, "model_format", "")
set_weight_attrs(self.patch_embed.proj.weight, {"model_format": model_format})
set_weight_attrs(self.patch_embed.proj.weight, {"weight_need_transpose": model_format == "torch"})
head_dim = config.vision_config.embed_dim // config.vision_config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

View File

@@ -181,8 +181,8 @@ class VariableResolutionResamplerModel(nn.Layer):
nn.Linear(self.spatial_dim, self.spatial_dim),
nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
)
set_weight_attrs(self.spatial_linear[0].weight, {"model_format": config.model_format})
set_weight_attrs(self.spatial_linear[2].weight, {"model_format": config.model_format})
set_weight_attrs(self.spatial_linear[0].weight, {"weight_need_transpose": config.model_format == "torch"})
set_weight_attrs(self.spatial_linear[2].weight, {"weight_need_transpose": config.model_format == "torch"})
if self.use_temporal_conv:
self.temporal_linear = nn.Sequential(
@@ -191,12 +191,16 @@ class VariableResolutionResamplerModel(nn.Layer):
nn.Linear(self.spatial_dim, self.spatial_dim),
nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
)
set_weight_attrs(self.temporal_linear[0].weight, {"model_format": config.model_format})
set_weight_attrs(self.temporal_linear[2].weight, {"model_format": config.model_format})
set_weight_attrs(
self.temporal_linear[0].weight, {"weight_need_transpose": config.model_format == "torch"}
)
set_weight_attrs(
self.temporal_linear[2].weight, {"weight_need_transpose": config.model_format == "torch"}
)
self.mlp = nn.Linear(self.spatial_dim, self.out_dim)
set_weight_attrs(self.mlp.weight, {"model_format": config.model_format})
set_weight_attrs(self.mlp.weight, {"weight_need_transpose": config.model_format == "torch"})
out_config = deepcopy(config)
out_config.hidden_size = out_dim

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
import re
from contextlib import contextmanager
from typing import Any, Optional, Union
@@ -158,8 +159,8 @@ def default_weight_loader(fd_config: FDConfig) -> None:
def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
"""fn"""
output_dim = getattr(param, "output_dim", None)
model_format = getattr(param, "model_format", "")
if model_format == "torch":
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
@@ -177,6 +178,9 @@ def default_weight_loader(fd_config: FDConfig) -> None:
loaded_weight = get_tensor(loaded_weight)
# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
if param.dtype != loaded_weight.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
if param.shape != loaded_weight.shape:
# for e_score_correction_bias
@@ -210,3 +214,50 @@ def switch_config_context(config_obj, config_attr_name, value):
yield
finally:
setattr(config_obj, config_attr_name, origin_value)
def rename_offline_ckpt_suffix_to_fd_suffix(
fd_config, ckpt_weight_suffix: str = "quant_weight", ckpt_scale_suffix="weight_scale"
):
"""
Create a function to rename checkpoint key suffixes for FastDeploy.
Replaces the checkpoint weight suffix (default "quant_weight") with "weight" and the
checkpoint scale suffix (default "weight_scale") with "weight_scale_inv". Only the suffix is changed.
Args:
fd_config: FastDeploy configuration.
ckpt_weight_suffix: Quantized-weight key suffix in the offline checkpoint.
ckpt_scale_suffix: Weight-scale key suffix in the offline checkpoint.
Returns:
Callable: Function that renames checkpoint keys.
"""
fd_suffix_map = {} # noqa: F841
fp8_suffix_map = {
ckpt_weight_suffix: "weight",
ckpt_scale_suffix: "weight_scale_inv",
}
moe_quant_type = ""
dense_quant_type = ""
if fd_config.quant_config is not None:
if fd_config.quant_config.name() == "mix_quant":
moe_quant_type = fd_config.quant_config.moe_quant_type
dense_quant_type = fd_config.quant_config.dense_quant_type
else:
moe_quant_type = fd_config.quant_config.name()
dense_quant_type = fd_config.quant_config.name()
def fn(loaded_weight_name, is_moe):
if fd_config.quant_config is None or fd_config.quant_config.is_checkpoint_bf16:
return loaded_weight_name
# Can be extended to other offline quantization suffixes if needed.
if (is_moe and moe_quant_type == "block_wise_fp8") or (not is_moe and dense_quant_type == "block_wise_fp8"):
fd_suffix_map = fp8_suffix_map
for ckpt_suffix, fd_suffix in fd_suffix_map.items():
if re.search(rf"{ckpt_suffix}$", loaded_weight_name):
loaded_weight_name = loaded_weight_name.replace(ckpt_suffix, fd_suffix)
return loaded_weight_name
return loaded_weight_name
return fn
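Usage sketch for the helper above (key names are made up for illustration): with a block-wise FP8 quant config and an offline checkpoint, the suffix rewrite maps exported names onto the parameter names FastDeploy registers:

rename_fn = rename_offline_ckpt_suffix_to_fd_suffix(fd_config)  # fd_config assumed to carry a block_wise_fp8 quant config
rename_fn("layers.0.mlp.up_gate_proj.quant_weight", is_moe=False)
# -> "layers.0.mlp.up_gate_proj.weight"
rename_fn("layers.0.mlp.up_gate_proj.weight_scale", is_moe=False)
# -> "layers.0.mlp.up_gate_proj.weight_scale_inv"
# bf16 checkpoints (is_checkpoint_bf16) are returned unchanged.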

View File

@@ -42,7 +42,7 @@ from fastdeploy.config import (
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.layers.quantization import get_quantization_config
from fastdeploy.model_executor.layers.quantization import parse_quant_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger, parse_quantization
from fastdeploy.worker.worker_base import WorkerBase
@@ -698,50 +698,12 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
if getattr(model_config, "num_hidden_layers", None) is None:
raise ValueError("num_hidden_layers is None")
quantization_config = model_config.quantization_config
if not model_config.is_quantized:
if quantization_config is not None:
if "is_quantized" in quantization_config:
model_config.is_quantized = quantization_config["is_quantized"]
elif "kv_cache_quant_type" not in quantization_config:
model_config.is_quantized = True
quant_config_name = None
if quantization_config is not None and quantization_config.get("quantization", None) is None:
raise ValueError("quantization_config should have a key named 'quantization' for specify quant config.")
if quantization_config is not None:
quant_config_name = quantization_config["quantization"]
# TODO(YuanRisheng) is_checkpoint_bf16 may need to be removed and replaced by is_quantized in future
if "kv_cache_quant_type" in quantization_config and load_config.load_choices == "default_v1":
quantization_config["is_checkpoint_bf16"] = True
elif args.quantization is not None:
quantization_config = {}
try:
quantization_config.update(args.quantization)
quant_config_name = quantization_config["quantization"]
except:
quant_config_name = args.quantization["quantization"]
quantization_config["quantization"] = quant_config_name
# Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization.
if load_config.load_choices == "default_v1":
quantization_config["is_checkpoint_bf16"] = True
# Special handling for Ernie models
is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures)
if quant_config_name == "wint4" and is_ernie:
quantization_config["dense_quant_type"] = "wint8"
quantization_config["moe_quant_type"] = "wint4"
quantization_config["quantization"] = "mix_quant"
quant_config_name = "mix_quant"
else:
quant_config_name = None
if quant_config_name is None:
quant_config = None
else:
quant_cls = get_quantization_config(quant_config_name)
quant_config = quant_cls.from_config(quantization_config)
quant_config = parse_quant_config(
args,
model_config,
is_ernie=ErnieArchitectures.contains_ernie_arch(model_config.architectures),
is_v1_loader=load_config.load_choices == "default_v1",
)
# Log quantization info
logger.info("===========quantization_config==============")
@@ -751,7 +713,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
else:
logger.info("Model Status: Original (will apply online quantization)")
logger.info(f"{quantization_config}")
logger.info(f"{model_config.quantization_config}")
else:
logger.info("No quantization config found and use original weight and act dtype.")

View File

@@ -53,12 +53,8 @@ class FDRunner:
req_outputs = self.llm.generate(prompts, sampling_params=sampling_params, **kwargs)
outputs: list[tuple[list[list[int]], list[str]]] = []
sample_output_ids: list[list[int]] = []
sample_output_strs: list[str] = []
for output in req_outputs:
sample_output_ids.append(output.outputs.token_ids)
sample_output_strs.append(output.outputs.text)
outputs.append((sample_output_ids, sample_output_strs))
outputs.append((output.outputs.token_ids, output.outputs.text))
return outputs
def generate_topp0(
@@ -69,7 +65,7 @@ class FDRunner:
) -> list[tuple[list[int], str]]:
from fastdeploy.engine.sampling_params import SamplingParams
topp_params = SamplingParams(temperature=0.1, top_p=0, max_tokens=max_tokens)
topp_params = SamplingParams(temperature=0.0, top_p=0, max_tokens=max_tokens)
outputs = self.generate(prompts, topp_params, **kwargs)
return outputs

View File

@@ -0,0 +1,104 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import pytest
prompts = ["解释下'温故而知新'", "who are you?"]
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from tests.model_loader.utils import (
form_model_get_output_topp0,
get_torch_model_path,
run_with_timeout,
)
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
model_param_map = {
"Qwen3-30B-A3B-FP8": {
"tensor_parallel_size": 2,
"quantizations": [
{
"quant_type": "None",
"backend": "triton",
"env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"},
},
],
},
}
params = []
for model, cfg in model_param_map.items():
for q in cfg["quantizations"]:
if isinstance(q, dict):
quant, backend, env = q["quant_type"], q.get("backend", "default"), q.get("env", {})
else:
quant, backend, env = q, "default", {}
params.append(
pytest.param(
model,
cfg.get("tensor_parallel_size", 1),
cfg.get("max_model_len", 1024),
quant,
cfg.get("max_tokens", 32),
env,
marks=[pytest.mark.core_model],
id=f"offline_quant_{model}.{quant}.{backend}",
)
)
@pytest.mark.parametrize(
"model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens,env",
params,
)
def test_offline_model(
fd_runner,
model_name_or_path: str,
tensor_parallel_size: int,
max_model_len: int,
max_tokens: int,
quantization: str,
env,
monkeypatch,
) -> None:
torch_model_path = get_torch_model_path(model_name_or_path)
if env:
for k, v in env.items():
monkeypatch.setenv(k, v)
_ = run_with_timeout(
target=form_model_get_output_topp0,
args=(
fd_runner,
torch_model_path,
tensor_parallel_size,
max_model_len,
max_tokens,
quantization,
"default_v1",
FD_ENGINE_QUEUE_PORT,
prompts,
FD_CACHE_QUEUE_PORT,
),
)

View File

@@ -181,7 +181,7 @@ def check_tokens_id_and_text_close(
outputs_1_lst: TokensIdText,
name_0: str,
name_1: str,
warn_on_mismatch: bool = True,
threshold: float = 0.0,
) -> None:
assert len(outputs_0_lst) == len(outputs_1_lst)
@@ -190,10 +190,21 @@ def check_tokens_id_and_text_close(
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1
if threshold > 0:
diff_rate = calculate_diff_rate(output_str_0, output_str_1)
if diff_rate >= threshold:
fail_msg = (
f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}"
f"\nDiff rate: {diff_rate:.4f} >= threshold: {threshold}"
)
raise AssertionError(fail_msg)
else:
# Loop through generated tokens.
for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
is_tok_mismatch = output_id_0 != output_id_1
if is_tok_mismatch and warn_on_mismatch:
if is_tok_mismatch:
fail_msg = (
f"Test{prompt_idx}:"
f"\nMatched tokens:\t{output_ids_0[:idx]}"
@@ -201,10 +212,6 @@ def check_tokens_id_and_text_close(
f"\n{name_1}:\t{output_str_1!r}"
)
raise AssertionError(fail_msg)
else:
if output_str_0 != output_str_1 and warn_on_mismatch:
fail_msg = f"Test{prompt_idx}:" f"\n{name_0}:\t{output_str_0!r}" f"\n{name_1}:\t{output_str_1!r}"
raise AssertionError(fail_msg)
def calculate_diff_rate(text1, text2):
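The body of calculate_diff_rate is not shown in this hunk; one plausible reading (ours, not necessarily the repo's implementation) is a normalized string difference:

import difflib

def calculate_diff_rate_sketch(text1: str, text2: str) -> float:
    # 0.0 means identical strings, 1.0 means completely different.
    if not text1 and not text2:
        return 0.0
    return 1.0 - difflib.SequenceMatcher(None, text1, text2).ratio()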