Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[v1 loader]qwen Offline fp8 (#4036)
* support offline fp8
* update unit tests
* fix
@@ -57,7 +57,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
{
**extra_weight_attrs,
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"model_format": extra_weight_attrs.get("model_format", ""),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
},
)

@@ -341,10 +341,10 @@ class MergedReplicatedLinear(ReplicatedLinear):
self.output_sizes = output_sizes

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
model_format = getattr(param, "model_format", "")
weight_need_transpose = getattr(param, "weight_need_transpose", False)
loaded_weight = get_tensor(loaded_weight)
if model_format == "torch":
if weight_need_transpose:
loaded_weight = loaded_weight.transpose([1, 0])
assert loaded_shard_id in ["q_a", "kv_a"]

@@ -365,6 +365,12 @@ class MergedReplicatedLinear(ReplicatedLinear):
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
param.copy_(loaded_weight, False)
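The int8-to-FP8 branch above reinterprets the serialized bytes instead of casting them. A minimal sketch of that rule, assuming a Paddle build with `float8_e4m3fn` support (the helper name below is hypothetical, not FastDeploy API):

```python
import paddle

def align_dtype(loaded_weight: paddle.Tensor, param_dtype) -> paddle.Tensor:
    # Offline FP8 checkpoints often serialize weights as raw int8 buffers,
    # so reinterpret the bits with .view() rather than a numeric cast,
    # which would change the stored values.
    if loaded_weight.dtype == param_dtype:
        return loaded_weight
    if loaded_weight.dtype == paddle.int8 and param_dtype == paddle.float8_e4m3fn:
        return loaded_weight.view(param_dtype)  # bitwise reinterpretation
    return loaded_weight.cast(param_dtype)      # ordinary numeric conversion
```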
@@ -483,15 +489,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
model_format = getattr(param, "model_format", "")
if model_format == "torch":
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
weight_need_transpose = getattr(param, "weight_need_transpose", False)
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
shard_dim = -1 if output_dim else 0
output_size = param.shape[shard_dim]
if loaded_shard_id is None:
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
param.weight_need_transpose = False
# Loaded weight is already fused on disk.
shard_offsets = [
# (shard_id, shard_offset, shard_size)

@@ -506,6 +513,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
else:
# split gate up
assert loaded_shard_id in ["gate", "up"]
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.nranks != 1:
dim = -1 if output_dim else 0

@@ -532,6 +542,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
param.copy_(loaded_weight, False)

def load_state_dict(self, state_dict: dict):
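For the fused-on-disk branch above, the shard bookkeeping amounts to splitting the fused gate/up tensor in half along the output dimension. A rough sketch with illustrative sizes, not the exact offsets the class computes:

```python
# Illustrative only: a fused gate/up projection of width 8 splits into two
# equal shards along the output dimension.
output_size = 8
shard_offsets = [
    # (shard_id, shard_offset, shard_size)
    ("gate", 0, output_size // 2),
    ("up", output_size // 2, output_size // 2),
]
for shard_id, offset, size in shard_offsets:
    print(shard_id, "covers columns", offset, "to", offset + size)
```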
@@ -604,11 +620,11 @@ class QKVParallelLinear(ColumnParallelLinear):
add_bias=add_bias,
)

def _get_shard_size_mapping(self, loaded_shard_id: str):
def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
shard_size_mapping = {
"q": self.num_heads_per_rank * self.head_dim,
"k": self.kv_num_heads_per_rank * self.head_dim,
"v": self.kv_num_heads_per_rank * self.head_dim,
"q": self.num_heads_per_rank * head_dim,
"k": self.kv_num_heads_per_rank * head_dim,
"v": self.kv_num_heads_per_rank * head_dim,
}
return shard_size_mapping.get(loaded_shard_id)

@@ -617,11 +633,12 @@ class QKVParallelLinear(ColumnParallelLinear):
assert output_dim is not None
dim = -1 if output_dim else 0
head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
model_format = getattr(param, "model_format", "")
if model_format == "torch":
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if loaded_shard_id is None:
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
param.weight_need_transpose = False
# Loaded weight is already fused on disk
shard_offsets = [
# (shard_id, shard_offset, shard_size)

@@ -637,13 +654,16 @@ class QKVParallelLinear(ColumnParallelLinear):
else:
# split q k v
assert loaded_shard_id in ["q", "k", "v"]
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.nranks != 1:
block_size = self._get_shard_size_mapping(loaded_shard_id)
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
shard_offset = shard_id * block_size
shard_size = (shard_id + 1) * block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
shard_size = block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)

loaded_weight = get_tensor(loaded_weight)
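The head_dim now passed into `_get_shard_size_mapping` is recovered from the parameter shape rather than taken from `self.head_dim`, so it matches the layout of the checkpoint actually being loaded. A small illustration with made-up GQA sizes:

```python
# Illustrative sizes only, not taken from any real config.
num_heads_per_rank = 8
kv_num_heads_per_rank = 2
qkv_output_dim = 1536  # fused q/k/v columns on this rank

head_dim = qkv_output_dim // (num_heads_per_rank + 2 * kv_num_heads_per_rank)
shard_size_mapping = {
    "q": num_heads_per_rank * head_dim,
    "k": kv_num_heads_per_rank * head_dim,
    "v": kv_num_heads_per_rank * head_dim,
}
print(head_dim, shard_size_mapping)  # 128 {'q': 1024, 'k': 256, 'v': 256}
```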
@@ -663,10 +683,17 @@ class QKVParallelLinear(ColumnParallelLinear):
param_shard_size = self.kv_num_heads_per_rank * head_dim
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
param.copy_(loaded_weight, False)

def load_weight(self, state_dict: dict):
@@ -91,7 +91,7 @@ class ParallelLMHead(nn.Layer):
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
"weight_need_transpose": self.fd_config.model_config.model_format == "torch",
},
)
if self.nranks > 1:

@@ -110,7 +110,7 @@ class ParallelLMHead(nn.Layer):
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
"weight_need_transpose": self.fd_config.model_config.model_format == "torch",
},
)
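Across these layers the commit replaces the per-parameter `model_format` attribute with an explicit `weight_need_transpose` flag. A minimal sketch of how a loader consumes it, assuming a torch-format checkpoint that stores linear weights as [out_features, in_features] while Paddle expects [in_features, out_features] (the helper below is illustrative, not FastDeploy API):

```python
import paddle

def maybe_transpose(param_attrs: dict, loaded_weight: paddle.Tensor) -> paddle.Tensor:
    # torch linear weights are [out, in]; Paddle linear weights are [in, out].
    if param_attrs.get("weight_need_transpose", False):
        loaded_weight = loaded_weight.transpose([1, 0])
    return loaded_weight

w_torch_layout = paddle.randn([8, 4])
print(maybe_transpose({"weight_need_transpose": True}, w_torch_layout).shape)  # [4, 8]
```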
@@ -216,18 +216,17 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)

set_weight_attrs(
layer.up_gate_proj_weight,
{
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"model_format": extra_weight_attrs.get("model_format", ""),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
},
)
set_weight_attrs(
layer.down_proj_weight,
{
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"model_format": extra_weight_attrs.get("model_format", ""),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
},
)
@@ -1024,8 +1024,8 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
]
self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]

if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
dtype=layer.weight_dtype,

@@ -1037,7 +1037,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)

extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{

@@ -1097,7 +1097,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
default_initializer=paddle.nn.initializer.Constant(0),
),
)

extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs)
set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs)
@@ -57,7 +57,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
]
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
dtype=layer.weight_dtype,

@@ -69,6 +70,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{

@@ -127,6 +129,25 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(
getattr(layer, up_gate_proj_weight_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, up_gate_proj_scale_name),
extra_weight_attrs,
)

set_weight_attrs(
getattr(layer, down_proj_weight_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, down_proj_scale_name),
extra_weight_attrs,
)

def process_weights_after_loading(self, layer):
""" """

@@ -169,6 +190,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
)
weight[expert_id].copy_(weight_quant, False)

getattr(layer, unquantized_weight_name).value().get_tensor()._clear()

# create weight
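Block-wise FP8 keeps one scale per weight block, so the scale tensors above have the ceil-divided shape of the corresponding weights. A quick arithmetic sketch with illustrative sizes:

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# Illustrative sizes, not taken from any real model config.
hidden_size, moe_intermediate_size = 4096, 1536
weight_block_size = [128, 128]

scale_shape = [
    ceil_div(hidden_size, weight_block_size[0]),
    ceil_div(moe_intermediate_size, weight_block_size[1]),
]
print(scale_shape)  # [32, 12] -> one float32 scale per 128x128 weight block
```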
@@ -72,7 +72,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer.moe_intermediate_size,
layer.hidden_size,
]
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=layer.weight_dtype,

@@ -84,6 +85,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"

set_weight_attrs(
layer.up_gate_proj_weight,
{

@@ -136,6 +139,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# support cache feature in future

def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
@@ -723,7 +727,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
]
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
dtype=layer.weight_dtype,

@@ -735,6 +740,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{

@@ -794,6 +800,26 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
),
)

extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(
getattr(layer, up_gate_proj_weight_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, up_gate_proj_scale_name),
extra_weight_attrs,
)

set_weight_attrs(
getattr(layer, down_proj_weight_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, down_proj_scale_name),
extra_weight_attrs,
)

def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
@@ -206,20 +206,19 @@ class FusedMoE(nn.Layer):

if shard_id is None:
# 1.gate up fused in disk
model_format = getattr(param, "model_format", "")
is_torch_model = model_format == "torch"
weight_need_transpose = getattr(param, "weight_need_transpose", False)
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
per_rank = output_size // 2
start = self.tp_rank * per_rank
loaded_weight_shard_gate = slice_fn(
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
)
self._load_gate_up_weight(
param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
)
start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
loaded_weight_shard_up = slice_fn(
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
)
self._load_gate_up_weight(
param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True

@@ -236,10 +235,9 @@ class FusedMoE(nn.Layer):
)

def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
model_format = getattr(param, "model_format", "")
is_torch_model = model_format == "torch"
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if self.tp_size > 1 and not is_sharded:
tp_shard_dim = is_torch_model ^ shard_dim
tp_shard_dim = weight_need_transpose ^ shard_dim
weight_dim = -1 if tp_shard_dim else 0
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
size = loaded_weight.shape[weight_dim]

@@ -275,13 +273,17 @@ class FusedMoE(nn.Layer):
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
if expert_param.dtype != loaded_weight.dtype:
if loaded_weight.dtype == paddle.int8 and expert_param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(expert_param.dtype)
else:
loaded_weight = loaded_weight.cast(expert_param.dtype)
expert_param.copy_(loaded_weight, False)

def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
model_format = getattr(param, "model_format", "")
is_torch_model = model_format == "torch"
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if self.tp_size > 1 and shard_dim is not None:
tp_shard_dim = is_torch_model ^ shard_dim
tp_shard_dim = weight_need_transpose ^ shard_dim
dim = -1 if tp_shard_dim else 0
if isinstance(loaded_weight, paddle.Tensor):
size = loaded_weight.shape[dim]

@@ -302,6 +304,11 @@ class FusedMoE(nn.Layer):
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
if expert_param.dtype != loaded_weight.dtype:
if loaded_weight.dtype == paddle.int8 and expert_param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(expert_param.dtype)
else:
loaded_weight = loaded_weight.cast(expert_param.dtype)
expert_param.copy_(loaded_weight, False)

def _load_expert_weight(
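The `weight_need_transpose ^ shard_dim` expression above flips the tensor-parallel shard dimension whenever the checkpoint stores the weight transposed. A tiny truth-table sketch:

```python
# When the on-disk weight is transposed, shard dim 0 becomes 1 and vice versa.
for weight_need_transpose in (False, True):
    for shard_dim in (0, 1):
        print(weight_need_transpose, shard_dim, "->", weight_need_transpose ^ shard_dim)
# False 0 -> 0, False 1 -> 1, True 0 -> 1, True 1 -> 0
```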
@@ -34,6 +34,72 @@ QUANTIZATION_METHODS: List[str] = [
]

def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
# 1.model_config.is_quantized
# TODO(bukejiyu) model_config.is_quantized is v0 only need to be removed in future
if model_config.model_format == "torch":
quantization_config = model_config.quantization_config
if quantization_config is not None:
model_config.is_quantized = True
else:
quantization_config = model_config.quantization_config
if not model_config.is_quantized:
if quantization_config is not None:
if "is_quantized" in quantization_config:
model_config.is_quantized = quantization_config["is_quantized"]
elif "kv_cache_quant_type" not in quantization_config:
model_config.is_quantized = True
if quantization_config is not None and quantization_config.get("quantization", None) is None:
raise ValueError(
"quantization_config should have a key named 'quantization' for specify quant config."
)

quant_config_name = None

if quantization_config is not None:
quant_config_name = _get_offline_quant_config_name(
quantization_config, model_config.model_format == "torch", is_v1_loader
)
elif args.quantization is not None:
quantization_config = {}
try:
quantization_config.update(args.quantization)
quant_config_name = quantization_config["quantization"]
except:
quant_config_name = args.quantization["quantization"]
quantization_config["quantization"] = quant_config_name
# Special handling for Ernie models
if quant_config_name == "wint4" and is_ernie:
quantization_config["dense_quant_type"] = "wint8"
quantization_config["moe_quant_type"] = "wint4"
quantization_config["quantization"] = "mix_quant"
quant_config_name = "mix_quant"
else:
quant_config_name = None
if quant_config_name is None:
quant_config = None
else:
if not quantization_config.get("is_quantized"):
quantization_config["is_quantized"] = model_config.is_quantized
quant_cls = get_quantization_config(quant_config_name)
quant_config = quant_cls.from_config(quantization_config)
return quant_config

def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_loader):
if is_torch_weight:
# only support block_wise_fp8 now
quant_method = quantization_config.get("quant_method")
has_block_size = "weight_block_size" in quantization_config
if quant_method == "fp8" and has_block_size:
quant_config_name = "block_wise_fp8"
else:
raise ValueError("Torch weight offline quantization only supports block-wise FP8.")
else:
quant_config_name = quantization_config["quantization"]
return quant_config_name
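As an illustration of the torch-format branch above, an offline FP8 checkpoint would typically carry a quantization_config along these lines (values are illustrative, not taken from a specific model):

```python
# Hypothetical quantization_config from a torch-format offline FP8 checkpoint.
quantization_config = {
    "quant_method": "fp8",
    "weight_block_size": [128, 128],
}

quant_method = quantization_config.get("quant_method")
has_block_size = "weight_block_size" in quantization_config
quant_config_name = "block_wise_fp8" if quant_method == "fp8" and has_block_size else None
print(quant_config_name)  # block_wise_fp8
```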

def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
"""
Get the quantization config class by the quantization name.
@@ -53,7 +53,7 @@ class BlockWiseFP8Config(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "BlockWiseFP8Config":
weight_block_size = config.get("weight_block_size", [128, 128])
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(weight_block_size, is_checkpoint_bf16)

def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
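Here (and in the weight-only and WFP8AFP8 configs below) the commit derives `is_checkpoint_bf16` from `is_quantized` instead of reading it directly from the config. A one-line sketch of the new rule:

```python
def is_checkpoint_bf16(config: dict) -> bool:
    # The checkpoint holds bf16 weights (to be quantized online) exactly when
    # it is not marked as already quantized offline.
    return not config.get("is_quantized", False)

print(is_checkpoint_bf16({}))                      # True  -> quantize online
print(is_checkpoint_bf16({"is_quantized": True}))  # False -> offline pre-quantized weights
```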
@@ -89,13 +89,15 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config

def create_weights(self, layer, **extra_weight_attrs):
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
quant_attrs = extra_weight_attrs
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
quant_attrs = {

@@ -120,14 +122,28 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):

layer.weight_scale_inv = layer.create_parameter(
shape=[
(layer.output_size + self.quant_config.weight_block_size[0] - 1)
(layer.weight_shape[0] + self.quant_config.weight_block_size[0] - 1)
// self.quant_config.weight_block_size[0],
(layer.input_size + self.quant_config.weight_block_size[1] - 1)
(layer.weight_shape[1] + self.quant_config.weight_block_size[1] - 1)
// self.quant_config.weight_block_size[1],
],
dtype="float32",
is_bias=False,
)
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]

extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.weight,
extra_weight_attrs,
)
set_weight_attrs(
layer.weight_scale_inv,
{
**extra_weight_attrs,
"is_scale": True,
},
)

def process_weights_after_loading(self, layer) -> None:
if not self.quant_config.is_checkpoint_bf16:
@@ -37,7 +37,7 @@ class MixQuantConfig(QuantConfigBase):
is_channel_wise: bool = False,
has_zero_point: bool = False,
is_permuted: bool = True,
is_checkpoint_bf16: bool = False,
is_quantized: bool = False,
hadamard_block_size: int = 128,
) -> None:
super().__init__()

@@ -54,7 +54,8 @@ class MixQuantConfig(QuantConfigBase):
self.quant_min_bound = 0
self.quant_round_type = 0
self.is_permuted = is_permuted
self.is_checkpoint_bf16 = is_checkpoint_bf16
self.is_checkpoint_bf16 = not is_quantized
self.is_quantized = is_quantized
self.hadamard_block_size = hadamard_block_size

def name(self) -> str:

@@ -70,7 +71,7 @@ class MixQuantConfig(QuantConfigBase):
config.get("is_channel_wise", False),
config.get("has_zero_point", False),
config.get("is_permuted", True),
config.get("is_checkpoint_bf16", False),
config.get("is_quantized", False),
config.get("hadamard_block_size", 128),
)

@@ -82,7 +83,7 @@ class MixQuantConfig(QuantConfigBase):
.from_config(
{
"is_permuted": self.is_permuted,
"is_checkpoint_bf16": self.is_checkpoint_bf16,
"is_quantized": self.is_quantized,
"hadamard_block_size": self.hadamard_block_size,
}
)

@@ -94,7 +95,7 @@ class MixQuantConfig(QuantConfigBase):
.from_config(
{
"is_permuted": self.is_permuted,
"is_checkpoint_bf16": self.is_checkpoint_bf16,
"is_quantized": self.is_quantized,
"hadamard_block_size": self.hadamard_block_size,
}
)

@@ -112,6 +113,6 @@ class MixQuantConfig(QuantConfigBase):
else:
return (
get_quantization_config(self.dense_quant_type)
.from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
.from_config({"is_quantized": self.is_quantized})
.get_quant_method(layer)
)
@@ -65,7 +65,7 @@ class WeightOnlyConfig(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "WeightOnlyConfig":
algo = config["algo"]
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(algo, is_checkpoint_bf16)

def get_quant_method(self, layer) -> Optional[QuantMethodBase]:

@@ -162,7 +162,7 @@ class WINT8Config(WeightOnlyConfig):

@classmethod
def from_config(cls, config: dict) -> "WINT8Config":
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(is_checkpoint_bf16)

def name(self) -> str:

@@ -182,7 +182,7 @@ class WINT4Config(WeightOnlyConfig):

@classmethod
def from_config(cls, config: dict) -> "WINT4Config":
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(is_checkpoint_bf16)

def name(self) -> str:

@@ -202,13 +202,15 @@ class WeightOnlyLinearMethod(QuantMethodBase):
self.quant_config = quant_config

def create_weights(self, layer, **extra_weight_attrs):
if self.quant_config.is_checkpoint_bf16:
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
quant_attrs = extra_weight_attrs
if (
isinstance(layer, MergedColumnParallelLinear)

@@ -256,6 +258,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
{
"weight_loader": weight_loader,
"output_dim": output_dim,
"weight_need_transpose": not extra_weight_attrs.get("model_format") == "torch",
},
)
@@ -60,7 +60,7 @@ class WFP8AFP8Config(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "WFP8AFP8Config":
""" """
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
is_checkpoint_bf16 = not config.get("is_quantized", False)
return cls(is_checkpoint_bf16=is_checkpoint_bf16)

def get_quant_method(self, layer) -> Optional[QuantMethodBase]:

@@ -92,13 +92,14 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
(weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1
)
scale_shape = scale_shape[::-1]
if self.quant_config.is_checkpoint_bf16:
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.weight = layer.create_parameter(
shape=weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
quant_attrs = extra_weight_attrs
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
quant_attrs = {
@@ -98,7 +98,7 @@ def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
)
enable_cache = True
weight_cache_context = switch_config_context(fd_config.quant_config, "is_checkpoint_bf16", False)
weight_cache_context = switch_config_context(fd_config.quant_config, "is_quantized", True)

return enable_cache, weight_cache_dir, weight_cache_context

@@ -150,7 +150,8 @@ def save_model(model_arg_name="model", config_arg_name="fd_config"):
)
_save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
else:
logger.info("Weights are already cached, skip saving")
reason = "weights already cached" if envs.FD_ENABLE_MODEL_LOAD_CACHE else "cache disabled"
logger.info(f"Skip saving ,{reason}")
return result

return wrapper
@@ -527,6 +527,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
rename_offline_ckpt_suffix_to_fd_suffix,
)

general_params_mapping = [

@@ -564,15 +565,20 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
param_down_proj_name="experts.down_proj_",
num_experts_start_offset=num_experts_start_offset,
)
all_param_mapping = general_params_mapping + expert_params_mapping

all_param_mapping = [
(param, weight, exp, shard, False) for param, weight, exp, shard in general_params_mapping
] + [(param, weight, exp, shard, True) for param, weight, exp, shard in expert_params_mapping]
checkpoint_to_fd_key_fn = rename_offline_ckpt_suffix_to_fd_suffix(
fd_config=self.fd_config, ckpt_weight_suffix="quant_weight", ckpt_scale_suffix="weight_scale"
)
params_dict = dict(self.named_parameters())

process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))

for loaded_weight_name, loaded_weight in weights_iterator:
loaded_weight_name = loaded_weight_name.replace("model", "ernie")
for param_name, weight_name, exp_id, shard_id in all_param_mapping:
for param_name, weight_name, exp_id, shard_id, is_moe in all_param_mapping:
loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe)
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue

@@ -583,6 +589,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
else:
expert_id = None
shard_id = None
loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe=False)
model_param_name = loaded_weight_name
if model_param_name not in params_dict.keys():
continue
@@ -193,16 +193,16 @@ class VisionFlashAttention2(nn.Layer):
self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
self.proj = nn.Linear(dim, dim)

set_weight_attrs(self.qkv.weight, {"model_format": model_format})
set_weight_attrs(self.proj.weight, {"model_format": model_format})
set_weight_attrs(self.qkv.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.proj.weight, {"weight_need_transpose": model_format == "torch"})
self.head_dim = dim // num_heads # must added
self.num_heads = num_heads
self.hidden_size = dim
self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
model_format = getattr(param, "model_format", "")
if model_format == "torch":
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = loaded_weight.transpose([1, 0])
load_bias = getattr(param, "load_bias", None)
if load_bias:

@@ -358,8 +358,8 @@ class VisionMlp(nn.Layer):
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)

set_weight_attrs(self.fc1.weight, {"model_format": model_format})
set_weight_attrs(self.fc2.weight, {"model_format": model_format})
set_weight_attrs(self.fc1.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.fc2.weight, {"weight_need_transpose": model_format == "torch"})

self.act = ACT2FN[hidden_act]

@@ -528,8 +528,10 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
in_channels=config.vision_config.in_channels,
embed_dim=config.vision_config.embed_dim,
)

model_format = getattr(config, "model_format", "")
set_weight_attrs(self.patch_embed.proj.weight, {"model_format": model_format})

set_weight_attrs(self.patch_embed.proj.weight, {"weight_need_transpose": model_format == "torch"})

head_dim = config.vision_config.embed_dim // config.vision_config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

@@ -181,8 +181,8 @@ class VariableResolutionResamplerModel(nn.Layer):
nn.Linear(self.spatial_dim, self.spatial_dim),
nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
)
set_weight_attrs(self.spatial_linear[0].weight, {"model_format": config.model_format})
set_weight_attrs(self.spatial_linear[2].weight, {"model_format": config.model_format})
set_weight_attrs(self.spatial_linear[0].weight, {"weight_need_transpose": config.model_format == "torch"})
set_weight_attrs(self.spatial_linear[2].weight, {"weight_need_transpose": config.model_format == "torch"})

if self.use_temporal_conv:
self.temporal_linear = nn.Sequential(

@@ -191,12 +191,16 @@ class VariableResolutionResamplerModel(nn.Layer):
nn.Linear(self.spatial_dim, self.spatial_dim),
nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
)
set_weight_attrs(self.temporal_linear[0].weight, {"model_format": config.model_format})
set_weight_attrs(self.temporal_linear[2].weight, {"model_format": config.model_format})
set_weight_attrs(
self.temporal_linear[0].weight, {"weight_need_transpose": config.model_format == "torch"}
)
set_weight_attrs(
self.temporal_linear[2].weight, {"weight_need_transpose": config.model_format == "torch"}
)

self.mlp = nn.Linear(self.spatial_dim, self.out_dim)

set_weight_attrs(self.mlp.weight, {"model_format": config.model_format})
set_weight_attrs(self.mlp.weight, {"weight_need_transpose": config.model_format == "torch"})

out_config = deepcopy(config)
out_config.hidden_size = out_dim
@@ -14,6 +14,7 @@
# limitations under the License.
"""

import re
from contextlib import contextmanager
from typing import Any, Optional, Union

@@ -158,8 +159,8 @@ def default_weight_loader(fd_config: FDConfig) -> None:
def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
"""fn"""
output_dim = getattr(param, "output_dim", None)
model_format = getattr(param, "model_format", "")
if model_format == "torch":
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim

@@ -177,7 +178,10 @@ def default_weight_loader(fd_config: FDConfig) -> None:
loaded_weight = get_tensor(loaded_weight)
# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
if param.dtype != loaded_weight.dtype:
loaded_weight = loaded_weight.cast(param.dtype)
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
if param.shape != loaded_weight.shape:
# for e_score_correction_bias
loaded_weight = loaded_weight.reshape(param.shape)

@@ -210,3 +214,50 @@ def switch_config_context(config_obj, config_attr_name, value):
yield
finally:
setattr(config_obj, config_attr_name, origin_value)

def rename_offline_ckpt_suffix_to_fd_suffix(
fd_config, ckpt_weight_suffix: str = "quant_weight", ckpt_scale_suffix="weight_scale"
):
"""
Create a function to rename checkpoint key suffixes for FastDeploy.

Replaces the checkpoint suffixes (default "quant_weight" and "weight_scale")
with the FastDeploy parameter suffixes ("weight" and "weight_scale_inv").
Only the suffix is changed.

Args:
fd_config: FastDeploy configuration.
ckpt_weight_suffix: Checkpoint weight key suffix.
ckpt_scale_suffix: Checkpoint scale key suffix.

Returns:
Callable: Function that renames checkpoint keys.
"""
fd_suffix_map = {}  # noqa: F841
fp8_suffix_map = {
ckpt_weight_suffix: "weight",
ckpt_scale_suffix: "weight_scale_inv",
}
moe_quant_type = ""
dense_quant_type = ""
if fd_config.quant_config is not None:
if fd_config.quant_config.name() == "mix_quant":
moe_quant_type = fd_config.quant_config.moe_quant_type
dense_quant_type = fd_config.quant_config.dense_quant_type
else:
moe_quant_type = fd_config.quant_config.name()
dense_quant_type = fd_config.quant_config.name()

def fn(loaded_weight_name, is_moe):
if fd_config.quant_config is None or fd_config.quant_config.is_checkpoint_bf16:
return loaded_weight_name
# Can be extended to other offline quantization suffixes if needed.
if (is_moe and moe_quant_type == "block_wise_fp8") or (not is_moe and dense_quant_type == "block_wise_fp8"):
fd_suffix_map = fp8_suffix_map
for ckpt_suffix, fd_suffix in fd_suffix_map.items():
if re.search(rf"{ckpt_suffix}$", loaded_weight_name):
loaded_weight_name = loaded_weight_name.replace(ckpt_suffix, fd_suffix)
return loaded_weight_name
return loaded_weight_name

return fn
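A hedged illustration of the suffix rewrite performed by the returned function, using a hypothetical checkpoint key (the sketch trims the suffix by slicing, whereas the function above uses `str.replace`; the result is the same for trailing suffixes):

```python
import re

fd_suffix_map = {"quant_weight": "weight", "weight_scale": "weight_scale_inv"}

def rename(key: str) -> str:
    for ckpt_suffix, fd_suffix in fd_suffix_map.items():
        if re.search(rf"{ckpt_suffix}$", key):
            return key[: -len(ckpt_suffix)] + fd_suffix
    return key

print(rename("ernie.layers.0.mlp.experts.0.up_gate_proj.quant_weight"))
# ernie.layers.0.mlp.experts.0.up_gate_proj.weight
print(rename("ernie.layers.0.mlp.experts.0.up_gate_proj.weight_scale"))
# ernie.layers.0.mlp.experts.0.up_gate_proj.weight_scale_inv
```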
@@ -42,7 +42,7 @@ from fastdeploy.config import (
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.layers.quantization import get_quantization_config
from fastdeploy.model_executor.layers.quantization import parse_quant_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger, parse_quantization
from fastdeploy.worker.worker_base import WorkerBase

@@ -698,50 +698,12 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
if getattr(model_config, "num_hidden_layers", None) is None:
raise ValueError("num_hidden_layers is None")

quantization_config = model_config.quantization_config
if not model_config.is_quantized:
if quantization_config is not None:
if "is_quantized" in quantization_config:
model_config.is_quantized = quantization_config["is_quantized"]
elif "kv_cache_quant_type" not in quantization_config:
model_config.is_quantized = True

quant_config_name = None
if quantization_config is not None and quantization_config.get("quantization", None) is None:
raise ValueError("quantization_config should have a key named 'quantization' for specify quant config.")

if quantization_config is not None:
quant_config_name = quantization_config["quantization"]
# TODO(YuanRisheng) is_checkpoint_bf16 may need to be removed and replaced by is_quantized in future
if "kv_cache_quant_type" in quantization_config and load_config.load_choices == "default_v1":
quantization_config["is_checkpoint_bf16"] = True

elif args.quantization is not None:
quantization_config = {}
try:
quantization_config.update(args.quantization)
quant_config_name = quantization_config["quantization"]
except:
quant_config_name = args.quantization["quantization"]
quantization_config["quantization"] = quant_config_name
# Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization.
if load_config.load_choices == "default_v1":
quantization_config["is_checkpoint_bf16"] = True
# Special handling for Ernie models
is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures)
if quant_config_name == "wint4" and is_ernie:
quantization_config["dense_quant_type"] = "wint8"
quantization_config["moe_quant_type"] = "wint4"
quantization_config["quantization"] = "mix_quant"
quant_config_name = "mix_quant"
else:
quant_config_name = None

if quant_config_name is None:
quant_config = None
else:
quant_cls = get_quantization_config(quant_config_name)
quant_config = quant_cls.from_config(quantization_config)
quant_config = parse_quant_config(
args,
model_config,
is_ernie=ErnieArchitectures.contains_ernie_arch(model_config.architectures),
is_v1_loader=load_config.load_choices == "default_v1",
)

# Log quantization info
logger.info("===========quantization_config==============")

@@ -751,7 +713,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
else:
logger.info("Model Status: Original (will apply online quantization)")

logger.info(f"{quantization_config}")
logger.info(f"{model_config.quantization_config}")
else:
logger.info("No quantization config found and use original weight and act dtype.")
@@ -53,12 +53,8 @@ class FDRunner:

req_outputs = self.llm.generate(prompts, sampling_params=sampling_params, **kwargs)
outputs: list[tuple[list[list[int]], list[str]]] = []
sample_output_ids: list[list[int]] = []
sample_output_strs: list[str] = []
for output in req_outputs:
sample_output_ids.append(output.outputs.token_ids)
sample_output_strs.append(output.outputs.text)
outputs.append((sample_output_ids, sample_output_strs))
outputs.append((output.outputs.token_ids, output.outputs.text))
return outputs

def generate_topp0(

@@ -69,7 +65,7 @@ class FDRunner:
) -> list[tuple[list[int], str]]:
from fastdeploy.engine.sampling_params import SamplingParams

topp_params = SamplingParams(temperature=0.1, top_p=0, max_tokens=max_tokens)
topp_params = SamplingParams(temperature=0.0, top_p=0, max_tokens=max_tokens)
outputs = self.generate(prompts, topp_params, **kwargs)
return outputs
tests/model_loader/test_offline_model.py (new file, 104 lines)
@@ -0,0 +1,104 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import pytest

prompts = ["解释下'温故而知新'", "who are you?"]

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)

from tests.model_loader.utils import (
form_model_get_output_topp0,
get_torch_model_path,
run_with_timeout,
)

FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))

model_param_map = {
"Qwen3-30B-A3B-FP8": {
"tensor_parallel_size": 2,
"quantizations": [
{
"quant_type": "None",
"backend": "triton",
"env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"},
},
],
},
}

params = []
for model, cfg in model_param_map.items():
for q in cfg["quantizations"]:
if isinstance(q, dict):
quant, backend, env = q["quant_type"], q.get("backend", "default"), q.get("env", {})
else:
quant, backend, env = q, "default", {}
params.append(
pytest.param(
model,
cfg.get("tensor_parallel_size", 1),
cfg.get("max_model_len", 1024),
quant,
cfg.get("max_tokens", 32),
env,
marks=[pytest.mark.core_model],
id=f"offline_quant_{model}.{quant}.{backend}",
)
)

@pytest.mark.parametrize(
"model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens,env",
params,
)
def test_offline_model(
fd_runner,
model_name_or_path: str,
tensor_parallel_size: int,
max_model_len: int,
max_tokens: int,
quantization: str,
env,
monkeypatch,
) -> None:
torch_model_path = get_torch_model_path(model_name_or_path)
if env:
for k, v in env.items():
monkeypatch.setenv(k, v)

_ = run_with_timeout(
target=form_model_get_output_topp0,
args=(
fd_runner,
torch_model_path,
tensor_parallel_size,
max_model_len,
max_tokens,
quantization,
"default_v1",
FD_ENGINE_QUEUE_PORT,
prompts,
FD_CACHE_QUEUE_PORT,
),
@@ -181,7 +181,7 @@ def check_tokens_id_and_text_close(
outputs_1_lst: TokensIdText,
name_0: str,
name_1: str,
warn_on_mismatch: bool = True,
threshold: float = 0.0,
) -> None:
assert len(outputs_0_lst) == len(outputs_1_lst)

@@ -190,21 +190,28 @@ def check_tokens_id_and_text_close(
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1

# Loop through generated tokens.
for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
is_tok_mismatch = output_id_0 != output_id_1
if is_tok_mismatch and warn_on_mismatch:
if threshold > 0:
diff_rate = calculate_diff_rate(output_str_0, output_str_1)
if diff_rate >= threshold:
fail_msg = (
f"Test{prompt_idx}:"
f"\nMatched tokens:\t{output_ids_0[:idx]}"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}"
f"\nDiff rate: {diff_rate:.4f} >= threshold: {threshold}"
)
raise AssertionError(fail_msg)
else:
if output_str_0 != output_str_1 and warn_on_mismatch:
fail_msg = f"Test{prompt_idx}:" f"\n{name_0}:\t{output_str_0!r}" f"\n{name_1}:\t{output_str_1!r}"
raise AssertionError(fail_msg)
else:
# Loop through generated tokens.
for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
is_tok_mismatch = output_id_0 != output_id_1
if is_tok_mismatch:
fail_msg = (
f"Test{prompt_idx}:"
f"\nMatched tokens:\t{output_ids_0[:idx]}"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}"
)
raise AssertionError(fail_msg)

def calculate_diff_rate(text1, text2):