Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-27 04:46:16 +08:00)
[v1 loader]qwen Offline fp8 (#4036)
* support offline fp8
* update ut
* update ut
* update ut
* fix
* update
* update
@@ -57,7 +57,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
             {
                 **extra_weight_attrs,
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
-                "model_format": extra_weight_attrs.get("model_format", ""),
+                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
             },
         )

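A brief illustration (not part of this commit): the boolean `weight_need_transpose` attribute encodes the layout difference between checkpoint formats up front, so loaders no longer compare the `model_format` string. A minimal sketch, assuming a torch-format linear weight is stored as [out_features, in_features] while the Paddle parameter expects [in_features, out_features]:

    # Sketch only; illustrative shapes, not FastDeploy API.
    import paddle

    param = paddle.zeros([1024, 4096])       # Paddle layout: [in_features, out_features]
    loaded = paddle.randn([4096, 1024])      # torch layout: [out_features, in_features]
    weight_need_transpose = True             # precomputed from model_format == "torch"
    if weight_need_transpose:
        loaded = loaded.transpose([1, 0])    # now [1024, 4096], matches param
    assert loaded.shape == param.shape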
@@ -341,10 +341,10 @@ class MergedReplicatedLinear(ReplicatedLinear):
         self.output_sizes = output_sizes

     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        model_format = getattr(param, "model_format", "")
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
         loaded_weight = get_tensor(loaded_weight)

-        if model_format == "torch":
+        if weight_need_transpose:
             loaded_weight = loaded_weight.transpose([1, 0])

         assert loaded_shard_id in ["q_a", "kv_a"]
@@ -365,6 +365,12 @@ class MergedReplicatedLinear(ReplicatedLinear):
         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
         )
+        # Ensure loaded weight dtype matches model param dtype
+        if loaded_weight.dtype != param.dtype:
+            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.view(param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(param.dtype)
         param.copy_(loaded_weight, False)

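A brief illustration (not part of this commit) of the dtype handling added above: offline FP8 checkpoints typically store the quantized weights as raw int8 bytes, so matching the parameter dtype means reinterpreting those bytes rather than converting values. A minimal sketch, assuming Paddle's float8_e4m3fn dtype is available:

    # Sketch only: bit-level reinterpretation vs. numeric conversion.
    import paddle

    def match_dtype(loaded_weight, param_dtype):
        if loaded_weight.dtype == param_dtype:
            return loaded_weight
        if loaded_weight.dtype == paddle.int8 and param_dtype == paddle.float8_e4m3fn:
            # Reinterpret the stored bytes as FP8 values; no numeric change.
            return loaded_weight.view(param_dtype)
        # Any other mismatch is a genuine value conversion.
        return loaded_weight.cast(param_dtype)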
@@ -483,15 +489,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         )

     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        model_format = getattr(param, "model_format", "")
-        if model_format == "torch":
-            loaded_weight = get_tensor(loaded_weight)
-            loaded_weight = loaded_weight.transpose([1, 0])
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
         output_dim = getattr(param, "output_dim", None)
         assert output_dim is not None
         shard_dim = -1 if output_dim else 0
         output_size = param.shape[shard_dim]
         if loaded_shard_id is None:
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
+                param.weight_need_transpose = False
             # Loaded weight is already fused on disk.
             shard_offsets = [
                 # (shard_id, shard_offset, shard_size)
@@ -506,6 +513,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         else:
             # split gate up
             assert loaded_shard_id in ["gate", "up"]
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
                 dim = -1 if output_dim else 0
@@ -532,6 +542,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
         )
+        # Ensure loaded weight dtype matches model param dtype
+        if loaded_weight.dtype != param.dtype:
+            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.view(param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(param.dtype)
         param.copy_(loaded_weight, False)

     def load_state_dict(self, state_dict: dict):
@@ -604,11 +620,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             add_bias=add_bias,
         )

-    def _get_shard_size_mapping(self, loaded_shard_id: str):
+    def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
         shard_size_mapping = {
-            "q": self.num_heads_per_rank * self.head_dim,
-            "k": self.kv_num_heads_per_rank * self.head_dim,
-            "v": self.kv_num_heads_per_rank * self.head_dim,
+            "q": self.num_heads_per_rank * head_dim,
+            "k": self.kv_num_heads_per_rank * head_dim,
+            "v": self.kv_num_heads_per_rank * head_dim,
         }
         return shard_size_mapping.get(loaded_shard_id)

@@ -617,11 +633,12 @@ class QKVParallelLinear(ColumnParallelLinear):
         assert output_dim is not None
         dim = -1 if output_dim else 0
         head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
-        model_format = getattr(param, "model_format", "")
-        if model_format == "torch":
-            loaded_weight = get_tensor(loaded_weight)
-            loaded_weight = loaded_weight.transpose([1, 0])
-        if loaded_shard_id is None:
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
+        if loaded_shard_id is None:
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
+                param.weight_need_transpose = False
             # Loaded weight is already fused on disk
             shard_offsets = [
                 # (shard_id, shard_offset, shard_size)
@@ -637,13 +654,16 @@ class QKVParallelLinear(ColumnParallelLinear):
         else:
             # split q k v
             assert loaded_shard_id in ["q", "k", "v"]
+            if weight_need_transpose:
+                loaded_weight = get_tensor(loaded_weight)
+                loaded_weight = loaded_weight.transpose([1, 0])
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
-                block_size = self._get_shard_size_mapping(loaded_shard_id)
+                block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
                 shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
                 shard_offset = shard_id * block_size
-                shard_size = (shard_id + 1) * block_size
-                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
+                shard_size = block_size
+                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)

             loaded_weight = get_tensor(loaded_weight)

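A brief illustration (not part of this commit): the slice above now passes an explicit offset plus size instead of reusing a running end index. With illustrative numbers the result is unchanged, but the naming is clearer:

    # Sketch only, illustrative numbers: two ranks, 512 rows per shard.
    block_size = 512
    for shard_id in (0, 1):
        shard_offset = shard_id * block_size      # 0, then 512
        shard_size = block_size                   # was (shard_id + 1) * block_size
        start, end = shard_offset, shard_offset + shard_size
        # rank 0 slices [0, 512), rank 1 slices [512, 1024); the old end value
        # (shard_id + 1) * block_size is numerically the same, just mislabeled as a size.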
@@ -663,10 +683,17 @@ class QKVParallelLinear(ColumnParallelLinear):
             param_shard_size = self.kv_num_heads_per_rank * head_dim
             if hasattr(param, "tensor_track"):
                 param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

             param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
         )
+        # Ensure loaded weight dtype matches model param dtype
+        if loaded_weight.dtype != param.dtype:
+            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.view(param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(param.dtype)
         param.copy_(loaded_weight, False)

     def load_weight(self, state_dict: dict):
@@ -91,7 +91,7 @@ class ParallelLMHead(nn.Layer):
                 self.linear.weight,
                 {
                     "weight_loader": default_weight_loader(self.fd_config),
-                    "model_format": self.fd_config.model_config.model_format,
+                    "weight_need_transpose": self.fd_config.model_config.model_format == "torch",
                 },
             )
             if self.nranks > 1:
@@ -110,7 +110,7 @@ class ParallelLMHead(nn.Layer):
                 self.linear.weight,
                 {
                     "weight_loader": default_weight_loader(self.fd_config),
-                    "model_format": self.fd_config.model_config.model_format,
+                    "weight_need_transpose": self.fd_config.model_config.model_format == "torch",
                 },
             )

@@ -216,18 +216,17 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
             dtype=layer.weight_dtype,
             default_initializer=paddle.nn.initializer.Constant(0),
         )

         set_weight_attrs(
             layer.up_gate_proj_weight,
             {
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
-                "model_format": extra_weight_attrs.get("model_format", ""),
+                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
             },
         )
         set_weight_attrs(
             layer.down_proj_weight,
             {
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
-                "model_format": extra_weight_attrs.get("model_format", ""),
+                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
             },
         )

@@ -1024,8 +1024,8 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         ]
         self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
         self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
-
-        if self.quant_config.is_checkpoint_bf16:
+        # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
+        if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
             layer.up_gate_proj_weight = layer.create_parameter(
                 shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
                 dtype=layer.weight_dtype,
@@ -1037,7 +1037,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
                 dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
-
+            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
             set_weight_attrs(
                 layer.up_gate_proj_weight,
                 {
@@ -1097,7 +1097,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ),
             )
-
+            extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
             moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
             set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs)
             set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs)
@@ -57,7 +57,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
             ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
         ]
-        if self.quant_config.is_checkpoint_bf16:
+        # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
+        if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
             layer.up_gate_proj_weight = layer.create_parameter(
                 shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
                 dtype=layer.weight_dtype,
@@ -69,6 +70,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
+            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
             set_weight_attrs(
                 layer.up_gate_proj_weight,
                 {
@@ -127,6 +129,25 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ),
             )
+            extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
+            extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
+            set_weight_attrs(
+                getattr(layer, up_gate_proj_weight_name),
+                extra_weight_attrs,
+            )
+            set_weight_attrs(
+                getattr(layer, up_gate_proj_scale_name),
+                extra_weight_attrs,
+            )
+
+            set_weight_attrs(
+                getattr(layer, down_proj_weight_name),
+                extra_weight_attrs,
+            )
+            set_weight_attrs(
+                getattr(layer, down_proj_scale_name),
+                extra_weight_attrs,
+            )

     def process_weights_after_loading(self, layer):
         """ """
@@ -169,6 +190,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                     getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
                 )
                 weight[expert_id].copy_(weight_quant, False)

             getattr(layer, unquantized_weight_name).value().get_tensor()._clear()

             # create weight
@@ -72,7 +72,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             layer.moe_intermediate_size,
             layer.hidden_size,
         ]
-        if self.quant_config.is_checkpoint_bf16:
+        # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
+        if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
             layer.up_gate_proj_weight = layer.create_parameter(
                 shape=self.up_gate_proj_weight_shape,
                 dtype=layer.weight_dtype,
@@ -84,6 +85,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
                 dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
+            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
+
             set_weight_attrs(
                 layer.up_gate_proj_weight,
                 {
@@ -136,6 +139,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
                     default_initializer=paddle.nn.initializer.Constant(0),
                 ),
             )
+            # support cache feature in future

     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
@@ -723,7 +727,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
             ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
         ]
-        if self.quant_config.is_checkpoint_bf16:
+        # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
+        if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
             layer.up_gate_proj_weight = layer.create_parameter(
                 shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
                 dtype=layer.weight_dtype,
@@ -735,6 +740,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
                 dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
+            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
             set_weight_attrs(
                 layer.up_gate_proj_weight,
                 {
@@ -794,6 +800,26 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
                 ),
             )

+            extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
+            extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
+            set_weight_attrs(
+                getattr(layer, up_gate_proj_weight_name),
+                extra_weight_attrs,
+            )
+            set_weight_attrs(
+                getattr(layer, up_gate_proj_scale_name),
+                extra_weight_attrs,
+            )
+
+            set_weight_attrs(
+                getattr(layer, down_proj_weight_name),
+                extra_weight_attrs,
+            )
+            set_weight_attrs(
+                getattr(layer, down_proj_scale_name),
+                extra_weight_attrs,
+            )

     def process_weights_after_loading(self, layer):
         """ """
         if not self.quant_config.is_checkpoint_bf16:
@@ -206,20 +206,19 @@ class FusedMoE(nn.Layer):

         if shard_id is None:
             # 1.gate up fused in disk
-            model_format = getattr(param, "model_format", "")
-            is_torch_model = model_format == "torch"
+            weight_need_transpose = getattr(param, "weight_need_transpose", False)
             output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
             per_rank = output_size // 2
             start = self.tp_rank * per_rank
             loaded_weight_shard_gate = slice_fn(
-                loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
+                loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
             )
             self._load_gate_up_weight(
                 param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
             )
             start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
             loaded_weight_shard_up = slice_fn(
-                loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
+                loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
             )
             self._load_gate_up_weight(
                 param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
@@ -236,10 +235,9 @@ class FusedMoE(nn.Layer):
             )

     def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
-        model_format = getattr(param, "model_format", "")
-        is_torch_model = model_format == "torch"
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
         if self.tp_size > 1 and not is_sharded:
-            tp_shard_dim = is_torch_model ^ shard_dim
+            tp_shard_dim = weight_need_transpose ^ shard_dim
             weight_dim = -1 if tp_shard_dim else 0
             if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[weight_dim]
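A brief illustration (not part of this commit) of the `weight_need_transpose ^ shard_dim` expression above: since `shard_dim` is 0 or 1 and the transpose flag is a bool, the XOR flips which axis is sharded whenever the on-disk layout is transposed. A minimal sketch with illustrative values:

    # Sketch only: XOR of a bool and a 0/1 shard dim selects the axis to slice.
    for weight_need_transpose in (False, True):
        for shard_dim in (0, 1):
            tp_shard_dim = weight_need_transpose ^ shard_dim
            # False ^ 0 -> 0, False ^ 1 -> 1 (layout unchanged)
            # True ^ 0  -> 1, True ^ 1  -> 0 (axes swapped for torch-format weights)
            print(weight_need_transpose, shard_dim, tp_shard_dim)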
@@ -275,13 +273,17 @@ class FusedMoE(nn.Layer):
         assert expert_param.shape == loaded_weight.shape, (
             f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
         )
+        if expert_param.dtype != loaded_weight.dtype:
+            if loaded_weight.dtype == paddle.int8 and expert_param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.view(expert_param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(expert_param.dtype)
         expert_param.copy_(loaded_weight, False)

     def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
-        model_format = getattr(param, "model_format", "")
-        is_torch_model = model_format == "torch"
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
         if self.tp_size > 1 and shard_dim is not None:
-            tp_shard_dim = is_torch_model ^ shard_dim
+            tp_shard_dim = weight_need_transpose ^ shard_dim
             dim = -1 if tp_shard_dim else 0
             if isinstance(loaded_weight, paddle.Tensor):
                 size = loaded_weight.shape[dim]
@@ -302,6 +304,11 @@ class FusedMoE(nn.Layer):
         assert expert_param.shape == loaded_weight.shape, (
             f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
         )
+        if expert_param.dtype != loaded_weight.dtype:
+            if loaded_weight.dtype == paddle.int8 and expert_param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.view(expert_param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(expert_param.dtype)
         expert_param.copy_(loaded_weight, False)

     def _load_expert_weight(
@@ -34,6 +34,72 @@ QUANTIZATION_METHODS: List[str] = [
 ]


+def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
+    # 1.model_config.is_quantized
+    # TODO(bukejiyu) model_config.is_quantized is v0 only need to be removed in future
+    if model_config.model_format == "torch":
+        quantization_config = model_config.quantization_config
+        if quantization_config is not None:
+            model_config.is_quantized = True
+    else:
+        quantization_config = model_config.quantization_config
+        if not model_config.is_quantized:
+            if quantization_config is not None:
+                if "is_quantized" in quantization_config:
+                    model_config.is_quantized = quantization_config["is_quantized"]
+                elif "kv_cache_quant_type" not in quantization_config:
+                    model_config.is_quantized = True
+    if quantization_config is not None and quantization_config.get("quantization", None) is None:
+        raise ValueError(
+            "quantization_config should have a key named 'quantization' for specify quant config."
+        )
+
+    quant_config_name = None
+
+    if quantization_config is not None:
+        quant_config_name = _get_offline_quant_config_name(
+            quantization_config, model_config.model_format == "torch", is_v1_loader
+        )
+    elif args.quantization is not None:
+        quantization_config = {}
+        try:
+            quantization_config.update(args.quantization)
+            quant_config_name = quantization_config["quantization"]
+        except:
+            quant_config_name = args.quantization["quantization"]
+            quantization_config["quantization"] = quant_config_name
+        # Special handling for Ernie models
+        if quant_config_name == "wint4" and is_ernie:
+            quantization_config["dense_quant_type"] = "wint8"
+            quantization_config["moe_quant_type"] = "wint4"
+            quantization_config["quantization"] = "mix_quant"
+            quant_config_name = "mix_quant"
+    else:
+        quant_config_name = None
+    if quant_config_name is None:
+        quant_config = None
+    else:
+        if not quantization_config.get("is_quantized"):
+            quantization_config["is_quantized"] = model_config.is_quantized
+        quant_cls = get_quantization_config(quant_config_name)
+        quant_config = quant_cls.from_config(quantization_config)
+    return quant_config
+
+
+def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_loader):
+    if is_torch_weight:
+        # only support block_wise_fp8 now
+        quant_method = quantization_config.get("quant_method")
+        has_block_size = "weight_block_size" in quantization_config
+        if quant_method == "fp8" and has_block_size:
+            quant_config_name = "block_wise_fp8"
+        else:
+            raise ValueError("Torch weight offline quantization only supports block-wise FP8.")
+    else:
+        quant_config_name = quantization_config["quantization"]
+    return quant_config_name
+
+
 def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
     """
     Get the quantization config class by the quantization name.
@@ -53,7 +53,7 @@ class BlockWiseFP8Config(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "BlockWiseFP8Config":
         weight_block_size = config.get("weight_block_size", [128, 128])
-        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        is_checkpoint_bf16 = not config.get("is_quantized", False)
         return cls(weight_block_size, is_checkpoint_bf16)

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -89,13 +89,15 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         self.quant_config = quant_config

     def create_weights(self, layer, **extra_weight_attrs):
-        if self.quant_config.is_checkpoint_bf16:
+        # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
+        if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
             layer.weight = layer.create_parameter(
                 shape=layer.weight_shape,
                 dtype=layer.weight_dtype,
                 is_bias=False,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
+            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
             quant_attrs = extra_weight_attrs
             if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
                 quant_attrs = {
@@ -120,14 +122,28 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):

             layer.weight_scale_inv = layer.create_parameter(
                 shape=[
-                    (layer.output_size + self.quant_config.weight_block_size[0] - 1)
+                    (layer.weight_shape[0] + self.quant_config.weight_block_size[0] - 1)
                     // self.quant_config.weight_block_size[0],
-                    (layer.input_size + self.quant_config.weight_block_size[1] - 1)
+                    (layer.weight_shape[1] + self.quant_config.weight_block_size[1] - 1)
                     // self.quant_config.weight_block_size[1],
                 ],
                 dtype="float32",
                 is_bias=False,
             )
+            extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
+
+            extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
+            set_weight_attrs(
+                layer.weight,
+                extra_weight_attrs,
+            )
+            set_weight_attrs(
+                layer.weight_scale_inv,
+                {
+                    **extra_weight_attrs,
+                    "is_scale": True,
+                },
+            )

     def process_weights_after_loading(self, layer) -> None:
         if not self.quant_config.is_checkpoint_bf16:
@@ -37,7 +37,7 @@ class MixQuantConfig(QuantConfigBase):
         is_channel_wise: bool = False,
         has_zero_point: bool = False,
         is_permuted: bool = True,
-        is_checkpoint_bf16: bool = False,
+        is_quantized: bool = False,
         hadamard_block_size: int = 128,
     ) -> None:
         super().__init__()
@@ -54,7 +54,8 @@ class MixQuantConfig(QuantConfigBase):
         self.quant_min_bound = 0
         self.quant_round_type = 0
         self.is_permuted = is_permuted
-        self.is_checkpoint_bf16 = is_checkpoint_bf16
+        self.is_checkpoint_bf16 = not is_quantized
+        self.is_quantized = is_quantized
         self.hadamard_block_size = hadamard_block_size

     def name(self) -> str:
@@ -70,7 +71,7 @@ class MixQuantConfig(QuantConfigBase):
             config.get("is_channel_wise", False),
             config.get("has_zero_point", False),
             config.get("is_permuted", True),
-            config.get("is_checkpoint_bf16", False),
+            config.get("is_quantized", False),
             config.get("hadamard_block_size", 128),
         )

@@ -82,7 +83,7 @@ class MixQuantConfig(QuantConfigBase):
                 .from_config(
                     {
                         "is_permuted": self.is_permuted,
-                        "is_checkpoint_bf16": self.is_checkpoint_bf16,
+                        "is_quantized": self.is_quantized,
                         "hadamard_block_size": self.hadamard_block_size,
                     }
                 )
@@ -94,7 +95,7 @@ class MixQuantConfig(QuantConfigBase):
                 .from_config(
                     {
                         "is_permuted": self.is_permuted,
-                        "is_checkpoint_bf16": self.is_checkpoint_bf16,
+                        "is_quantized": self.is_quantized,
                         "hadamard_block_size": self.hadamard_block_size,
                     }
                 )
@@ -112,6 +113,6 @@ class MixQuantConfig(QuantConfigBase):
         else:
             return (
                 get_quantization_config(self.dense_quant_type)
-                .from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
+                .from_config({"is_quantized": self.is_quantized})
                 .get_quant_method(layer)
             )
@@ -65,7 +65,7 @@ class WeightOnlyConfig(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "WeightOnlyConfig":
         algo = config["algo"]
-        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        is_checkpoint_bf16 = not config.get("is_quantized", False)
         return cls(algo, is_checkpoint_bf16)

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -162,7 +162,7 @@ class WINT8Config(WeightOnlyConfig):

     @classmethod
     def from_config(cls, config: dict) -> "WINT8Config":
-        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        is_checkpoint_bf16 = not config.get("is_quantized", False)
         return cls(is_checkpoint_bf16)

     def name(self) -> str:
@@ -182,7 +182,7 @@ class WINT4Config(WeightOnlyConfig):

     @classmethod
     def from_config(cls, config: dict) -> "WINT4Config":
-        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        is_checkpoint_bf16 = not config.get("is_quantized", False)
         return cls(is_checkpoint_bf16)

     def name(self) -> str:
@@ -202,13 +202,15 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         self.quant_config = quant_config

     def create_weights(self, layer, **extra_weight_attrs):
-        if self.quant_config.is_checkpoint_bf16:
+        # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
+        if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
             layer.weight = layer.create_parameter(
                 shape=layer.weight_shape,
                 dtype=layer.weight_dtype,
                 is_bias=False,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
+            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
             quant_attrs = extra_weight_attrs
             if (
                 isinstance(layer, MergedColumnParallelLinear)
@@ -256,6 +258,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
                 {
                     "weight_loader": weight_loader,
                     "output_dim": output_dim,
+                    "weight_need_transpose": not extra_weight_attrs.get("model_format") == "torch",
                 },
             )

@@ -60,7 +60,7 @@ class WFP8AFP8Config(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "WFP8AFP8Config":
         """ """
-        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        is_checkpoint_bf16 = not config.get("is_quantized", False)
         return cls(is_checkpoint_bf16=is_checkpoint_bf16)

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -92,13 +92,14 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
                 (weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1
             )
         scale_shape = scale_shape[::-1]
-        if self.quant_config.is_checkpoint_bf16:
+        if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
             layer.weight = layer.create_parameter(
                 shape=weight_shape,
                 dtype=layer.weight_dtype,
                 is_bias=False,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
+            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
             quant_attrs = extra_weight_attrs
             if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
                 quant_attrs = {
@@ -98,7 +98,7 @@ def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
            f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
        )
        enable_cache = True
-       weight_cache_context = switch_config_context(fd_config.quant_config, "is_checkpoint_bf16", False)
+       weight_cache_context = switch_config_context(fd_config.quant_config, "is_quantized", True)

    return enable_cache, weight_cache_dir, weight_cache_context

@@ -150,7 +150,8 @@ def save_model(model_arg_name="model", config_arg_name="fd_config"):
                )
                _save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
            else:
-               logger.info("Weights are already cached, skip saving")
+               reason = "weights already cached" if envs.FD_ENABLE_MODEL_LOAD_CACHE else "cache disabled"
+               logger.info(f"Skip saving ,{reason}")
            return result

        return wrapper
@@ -527,6 +527,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
         from fastdeploy.model_executor.utils import (
             default_weight_loader,
             process_weights_after_loading,
+            rename_offline_ckpt_suffix_to_fd_suffix,
         )

         general_params_mapping = [
@@ -564,15 +565,20 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
             param_down_proj_name="experts.down_proj_",
             num_experts_start_offset=num_experts_start_offset,
         )
-        all_param_mapping = general_params_mapping + expert_params_mapping
+        all_param_mapping = [
+            (param, weight, exp, shard, False) for param, weight, exp, shard in general_params_mapping
+        ] + [(param, weight, exp, shard, True) for param, weight, exp, shard in expert_params_mapping]
+        checkpoint_to_fd_key_fn = rename_offline_ckpt_suffix_to_fd_suffix(
+            fd_config=self.fd_config, ckpt_weight_suffix="quant_weight", ckpt_scale_suffix="weight_scale"
+        )
         params_dict = dict(self.named_parameters())

         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))

         for loaded_weight_name, loaded_weight in weights_iterator:
             loaded_weight_name = loaded_weight_name.replace("model", "ernie")
-            for param_name, weight_name, exp_id, shard_id in all_param_mapping:
+            for param_name, weight_name, exp_id, shard_id, is_moe in all_param_mapping:
+                loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe)
                 model_param_name = loaded_weight_name.replace(weight_name, param_name)
                 if model_param_name not in params_dict:
                     continue
@@ -583,6 +589,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
             else:
                 expert_id = None
                 shard_id = None
+                loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe=False)
                 model_param_name = loaded_weight_name
                 if model_param_name not in params_dict.keys():
                     continue
@@ -193,16 +193,16 @@ class VisionFlashAttention2(nn.Layer):
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
         self.proj = nn.Linear(dim, dim)

-        set_weight_attrs(self.qkv.weight, {"model_format": model_format})
-        set_weight_attrs(self.proj.weight, {"model_format": model_format})
+        set_weight_attrs(self.qkv.weight, {"weight_need_transpose": model_format == "torch"})
+        set_weight_attrs(self.proj.weight, {"weight_need_transpose": model_format == "torch"})
         self.head_dim = dim // num_heads  # must added
         self.num_heads = num_heads
         self.hidden_size = dim
         self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)

     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        model_format = getattr(param, "model_format", "")
-        if model_format == "torch":
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
+        if weight_need_transpose:
             loaded_weight = loaded_weight.transpose([1, 0])
         load_bias = getattr(param, "load_bias", None)
         if load_bias:
@@ -358,8 +358,8 @@ class VisionMlp(nn.Layer):
         self.fc1 = nn.Linear(dim, hidden_dim)
         self.fc2 = nn.Linear(hidden_dim, dim)

-        set_weight_attrs(self.fc1.weight, {"model_format": model_format})
-        set_weight_attrs(self.fc2.weight, {"model_format": model_format})
+        set_weight_attrs(self.fc1.weight, {"weight_need_transpose": model_format == "torch"})
+        set_weight_attrs(self.fc2.weight, {"weight_need_transpose": model_format == "torch"})

         self.act = ACT2FN[hidden_act]

@@ -528,8 +528,10 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
             in_channels=config.vision_config.in_channels,
             embed_dim=config.vision_config.embed_dim,
         )

         model_format = getattr(config, "model_format", "")
-        set_weight_attrs(self.patch_embed.proj.weight, {"model_format": model_format})
+
+        set_weight_attrs(self.patch_embed.proj.weight, {"weight_need_transpose": model_format == "torch"})
+
         head_dim = config.vision_config.embed_dim // config.vision_config.num_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -181,8 +181,8 @@ class VariableResolutionResamplerModel(nn.Layer):
             nn.Linear(self.spatial_dim, self.spatial_dim),
             nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
         )
-        set_weight_attrs(self.spatial_linear[0].weight, {"model_format": config.model_format})
-        set_weight_attrs(self.spatial_linear[2].weight, {"model_format": config.model_format})
+        set_weight_attrs(self.spatial_linear[0].weight, {"weight_need_transpose": config.model_format == "torch"})
+        set_weight_attrs(self.spatial_linear[2].weight, {"weight_need_transpose": config.model_format == "torch"})

         if self.use_temporal_conv:
             self.temporal_linear = nn.Sequential(
@@ -191,12 +191,16 @@ class VariableResolutionResamplerModel(nn.Layer):
                 nn.Linear(self.spatial_dim, self.spatial_dim),
                 nn.LayerNorm(self.spatial_dim, epsilon=1e-6),
             )
-            set_weight_attrs(self.temporal_linear[0].weight, {"model_format": config.model_format})
-            set_weight_attrs(self.temporal_linear[2].weight, {"model_format": config.model_format})
+            set_weight_attrs(
+                self.temporal_linear[0].weight, {"weight_need_transpose": config.model_format == "torch"}
+            )
+            set_weight_attrs(
+                self.temporal_linear[2].weight, {"weight_need_transpose": config.model_format == "torch"}
+            )

         self.mlp = nn.Linear(self.spatial_dim, self.out_dim)

-        set_weight_attrs(self.mlp.weight, {"model_format": config.model_format})
+        set_weight_attrs(self.mlp.weight, {"weight_need_transpose": config.model_format == "torch"})

         out_config = deepcopy(config)
         out_config.hidden_size = out_dim
@@ -14,6 +14,7 @@
 # limitations under the License.
 """

+import re
 from contextlib import contextmanager
 from typing import Any, Optional, Union

@@ -158,8 +159,8 @@ def default_weight_loader(fd_config: FDConfig) -> None:
     def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
         """fn"""
         output_dim = getattr(param, "output_dim", None)
-        model_format = getattr(param, "model_format", "")
-        if model_format == "torch":
+        weight_need_transpose = getattr(param, "weight_need_transpose", False)
+        if weight_need_transpose:
             loaded_weight = get_tensor(loaded_weight)
             loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
@@ -177,6 +178,9 @@ def default_weight_loader(fd_config: FDConfig) -> None:
             loaded_weight = get_tensor(loaded_weight)
             # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
             if param.dtype != loaded_weight.dtype:
-                loaded_weight = loaded_weight.cast(param.dtype)
+                if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                    loaded_weight = loaded_weight.view(param.dtype)
+                else:
+                    loaded_weight = loaded_weight.cast(param.dtype)
             if param.shape != loaded_weight.shape:
                 # for e_score_correction_bias
@@ -210,3 +214,50 @@ def switch_config_context(config_obj, config_attr_name, value):
         yield
     finally:
         setattr(config_obj, config_attr_name, origin_value)
+
+
+def rename_offline_ckpt_suffix_to_fd_suffix(
+    fd_config, ckpt_weight_suffix: str = "quant_weight", ckpt_scale_suffix="weight_scale"
+):
+    """
+    Create a function to rename checkpoint key suffixes for FastDeploy.
+
+    Maps offline-quantized checkpoint suffixes (default "quant_weight" and "weight_scale")
+    to the suffixes FastDeploy expects ("weight" and "weight_scale_inv"). Only the suffix is changed.
+
+    Args:
+        fd_config: FastDeploy configuration.
+        ckpt_weight_suffix: Weight-key suffix used in the checkpoint.
+        ckpt_scale_suffix: Scale-key suffix used in the checkpoint.
+
+    Returns:
+        Callable: Function that renames checkpoint keys.
+    """
+    fd_suffix_map = {}  # noqa: F841
+    fp8_suffix_map = {
+        ckpt_weight_suffix: "weight",
+        ckpt_scale_suffix: "weight_scale_inv",
+    }
+    moe_quant_type = ""
+    dense_quant_type = ""
+    if fd_config.quant_config is not None:
+        if fd_config.quant_config.name() == "mix_quant":
+            moe_quant_type = fd_config.quant_config.moe_quant_type
+            dense_quant_type = fd_config.quant_config.dense_quant_type
+        else:
+            moe_quant_type = fd_config.quant_config.name()
+            dense_quant_type = fd_config.quant_config.name()
+
+    def fn(loaded_weight_name, is_moe):
+        if fd_config.quant_config is None or fd_config.quant_config.is_checkpoint_bf16:
+            return loaded_weight_name
+        # Can be extended to other offline quantization suffixes if needed.
+        if (is_moe and moe_quant_type == "block_wise_fp8") or (not is_moe and dense_quant_type == "block_wise_fp8"):
+            fd_suffix_map = fp8_suffix_map
+            for ckpt_suffix, fd_suffix in fd_suffix_map.items():
+                if re.search(rf"{ckpt_suffix}$", loaded_weight_name):
+                    loaded_weight_name = loaded_weight_name.replace(ckpt_suffix, fd_suffix)
+                    return loaded_weight_name
+        return loaded_weight_name
+
+    return fn
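A hedged usage sketch of the renaming helper above. The config object here is a minimal stand-in (the real FDConfig and its quant config carry far more state), and the key names are hypothetical; only the suffix handling mirrors the function:

from types import SimpleNamespace

class _FakeQuantConfig:
    is_checkpoint_bf16 = False
    def name(self):
        return "block_wise_fp8"

fake_cfg = SimpleNamespace(quant_config=_FakeQuantConfig())
rename = rename_offline_ckpt_suffix_to_fd_suffix(fake_cfg)
print(rename("model.layers.0.mlp.up_proj.quant_weight", is_moe=False))   # ...up_proj.weight
print(rename("model.layers.0.mlp.up_proj.weight_scale", is_moe=False))   # ...up_proj.weight_scale_inv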
@@ -42,7 +42,7 @@ from fastdeploy.config import (
 from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
 from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
 from fastdeploy.inter_communicator import IPCSignal
-from fastdeploy.model_executor.layers.quantization import get_quantization_config
+from fastdeploy.model_executor.layers.quantization import parse_quant_config
 from fastdeploy.platforms import current_platform
 from fastdeploy.utils import get_logger, parse_quantization
 from fastdeploy.worker.worker_base import WorkerBase
@@ -698,50 +698,12 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     if getattr(model_config, "num_hidden_layers", None) is None:
         raise ValueError("num_hidden_layers is None")

-    quantization_config = model_config.quantization_config
-    if not model_config.is_quantized:
-        if quantization_config is not None:
-            if "is_quantized" in quantization_config:
-                model_config.is_quantized = quantization_config["is_quantized"]
-            elif "kv_cache_quant_type" not in quantization_config:
-                model_config.is_quantized = True
-
-    quant_config_name = None
-    if quantization_config is not None and quantization_config.get("quantization", None) is None:
-        raise ValueError("quantization_config should have a key named 'quantization' for specify quant config.")
-
-    if quantization_config is not None:
-        quant_config_name = quantization_config["quantization"]
-        # TODO(YuanRisheng) is_checkpoint_bf16 may need to be removed and replaced by is_quantized in future
-        if "kv_cache_quant_type" in quantization_config and load_config.load_choices == "default_v1":
-            quantization_config["is_checkpoint_bf16"] = True
-
-    elif args.quantization is not None:
-        quantization_config = {}
-        try:
-            quantization_config.update(args.quantization)
-            quant_config_name = quantization_config["quantization"]
-        except:
-            quant_config_name = args.quantization["quantization"]
-            quantization_config["quantization"] = quant_config_name
-        # Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization.
-        if load_config.load_choices == "default_v1":
-            quantization_config["is_checkpoint_bf16"] = True
-        # Special handling for Ernie models
-        is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures)
-        if quant_config_name == "wint4" and is_ernie:
-            quantization_config["dense_quant_type"] = "wint8"
-            quantization_config["moe_quant_type"] = "wint4"
-            quantization_config["quantization"] = "mix_quant"
-            quant_config_name = "mix_quant"
-    else:
-        quant_config_name = None
-
-    if quant_config_name is None:
-        quant_config = None
-    else:
-        quant_cls = get_quantization_config(quant_config_name)
-        quant_config = quant_cls.from_config(quantization_config)
+    quant_config = parse_quant_config(
+        args,
+        model_config,
+        is_ernie=ErnieArchitectures.contains_ernie_arch(model_config.architectures),
+        is_v1_loader=load_config.load_choices == "default_v1",
+    )

     # Log quantization info
     logger.info("===========quantization_config==============")
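For reference, the removed inline branch effectively built a dict like the one below when an Ernie architecture requested wint4; the new parse_quant_config call is assumed to reproduce the same decision. The values come straight from the deleted lines, while the snippet itself is only illustrative:

# Illustrative reconstruction of the removed special case (not the new API itself):
quantization_config = {
    "quantization": "mix_quant",       # wint4 on an Ernie arch was promoted to mix_quant
    "dense_quant_type": "wint8",
    "moe_quant_type": "wint4",
    "is_checkpoint_bf16": True,        # only when load_config.load_choices == "default_v1"
}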
@@ -751,7 +713,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         else:
             logger.info("Model Status: Original (will apply online quantization)")

-        logger.info(f"{quantization_config}")
+        logger.info(f"{model_config.quantization_config}")
     else:
         logger.info("No quantization config found and use original weight and act dtype.")

@@ -53,12 +53,8 @@ class FDRunner:

         req_outputs = self.llm.generate(prompts, sampling_params=sampling_params, **kwargs)
         outputs: list[tuple[list[list[int]], list[str]]] = []
-        sample_output_ids: list[list[int]] = []
-        sample_output_strs: list[str] = []
         for output in req_outputs:
-            sample_output_ids.append(output.outputs.token_ids)
-            sample_output_strs.append(output.outputs.text)
-        outputs.append((sample_output_ids, sample_output_strs))
+            outputs.append((output.outputs.token_ids, output.outputs.text))
         return outputs

     def generate_topp0(
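The shape of the returned value changes with this refactor. An illustrative comparison (token ids and text are made up):

# Before: one aggregated tuple for the whole batch.
outputs_before = [([[1, 2, 3], [4, 5]], ["hello", "world"])]
# After: one (token_ids, text) tuple per request.
outputs_after = [([1, 2, 3], "hello"), ([4, 5], "world")]
assert len(outputs_after) == 2

Note that the outputs type annotation kept as context above still describes the old aggregated shape rather than the per-request tuples.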
@@ -69,7 +65,7 @@ class FDRunner:
     ) -> list[tuple[list[int], str]]:
         from fastdeploy.engine.sampling_params import SamplingParams

-        topp_params = SamplingParams(temperature=0.1, top_p=0, max_tokens=max_tokens)
+        topp_params = SamplingParams(temperature=0.0, top_p=0, max_tokens=max_tokens)
         outputs = self.generate(prompts, topp_params, **kwargs)
         return outputs

tests/model_loader/test_offline_model.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+import pytest
+
+prompts = ["解释下'温故而知新'", "who are you?"]
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.abspath(os.path.join(current_dir, ".."))
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
+
+from tests.model_loader.utils import (
+    form_model_get_output_topp0,
+    get_torch_model_path,
+    run_with_timeout,
+)
+
+FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
+FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
+
+
+model_param_map = {
+    "Qwen3-30B-A3B-FP8": {
+        "tensor_parallel_size": 2,
+        "quantizations": [
+            {
+                "quant_type": "None",
+                "backend": "triton",
+                "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"},
+            },
+        ],
+    },
+}
+
+params = []
+for model, cfg in model_param_map.items():
+    for q in cfg["quantizations"]:
+        if isinstance(q, dict):
+            quant, backend, env = q["quant_type"], q.get("backend", "default"), q.get("env", {})
+        else:
+            quant, backend, env = q, "default", {}
+        params.append(
+            pytest.param(
+                model,
+                cfg.get("tensor_parallel_size", 1),
+                cfg.get("max_model_len", 1024),
+                quant,
+                cfg.get("max_tokens", 32),
+                env,
+                marks=[pytest.mark.core_model],
+                id=f"offline_quant_{model}.{quant}.{backend}",
+            )
+        )
+
+
+@pytest.mark.parametrize(
+    "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens,env",
+    params,
+)
+def test_offline_model(
+    fd_runner,
+    model_name_or_path: str,
+    tensor_parallel_size: int,
+    max_model_len: int,
+    max_tokens: int,
+    quantization: str,
+    env,
+    monkeypatch,
+) -> None:
+    torch_model_path = get_torch_model_path(model_name_or_path)
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
+
+    _ = run_with_timeout(
+        target=form_model_get_output_topp0,
+        args=(
+            fd_runner,
+            torch_model_path,
+            tensor_parallel_size,
+            max_model_len,
+            max_tokens,
+            quantization,
+            "default_v1",
+            FD_ENGINE_QUEUE_PORT,
+            prompts,
+            FD_CACHE_QUEUE_PORT,
+        ),
+    )
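A small, self-contained check of how the parametrisation above names its cases; the expected id is inferred from the f-string in the loop, not taken from a real test run:

model, quant, backend = "Qwen3-30B-A3B-FP8", "None", "triton"
case_id = f"offline_quant_{model}.{quant}.{backend}"
assert case_id == "offline_quant_Qwen3-30B-A3B-FP8.None.triton"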
@@ -181,7 +181,7 @@ def check_tokens_id_and_text_close(
     outputs_1_lst: TokensIdText,
     name_0: str,
     name_1: str,
-    warn_on_mismatch: bool = True,
+    threshold: float = 0.0,
 ) -> None:
     assert len(outputs_0_lst) == len(outputs_1_lst)

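A hedged usage sketch of the new parameter: with the default threshold of 0.0 the helper still demands token-level equality, while a positive threshold tolerates bounded text drift. The variable names below are hypothetical:

check_tokens_id_and_text_close(
    outputs_0_lst=baseline_outputs,       # e.g. an online-quantized run
    outputs_1_lst=offline_fp8_outputs,    # e.g. an offline FP8 checkpoint run
    name_0="online_quant",
    name_1="offline_fp8",
    threshold=0.2,                        # hypothetical tolerance for FP8 numerical drift
)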
@@ -190,10 +190,21 @@ def check_tokens_id_and_text_close(
         output_ids_0, output_str_0 = outputs_0
         output_ids_1, output_str_1 = outputs_1

+        if threshold > 0:
+            diff_rate = calculate_diff_rate(output_str_0, output_str_1)
+            if diff_rate >= threshold:
+                fail_msg = (
+                    f"Test{prompt_idx}:"
+                    f"\n{name_0}:\t{output_str_0!r}"
+                    f"\n{name_1}:\t{output_str_1!r}"
+                    f"\nDiff rate: {diff_rate:.4f} >= threshold: {threshold}"
+                )
+                raise AssertionError(fail_msg)
+        else:
             # Loop through generated tokens.
             for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
                 is_tok_mismatch = output_id_0 != output_id_1
-                if is_tok_mismatch and warn_on_mismatch:
+                if is_tok_mismatch:
                     fail_msg = (
                         f"Test{prompt_idx}:"
                         f"\nMatched tokens:\t{output_ids_0[:idx]}"
@@ -201,10 +212,6 @@ def check_tokens_id_and_text_close(
                         f"\n{name_1}:\t{output_str_1!r}"
                     )
                     raise AssertionError(fail_msg)
-        else:
-            if output_str_0 != output_str_1 and warn_on_mismatch:
-                fail_msg = f"Test{prompt_idx}:" f"\n{name_0}:\t{output_str_0!r}" f"\n{name_1}:\t{output_str_1!r}"
-                raise AssertionError(fail_msg)


 def calculate_diff_rate(text1, text2):