Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
[Inference Optimize] Update MergedReplicatedLinear for DSK qkv_a_proj_with_mqa. (#3673)
* support MergedReplicatedLinear
* update MergedReplicatedLinear to support DSK_wint4 V1_load
* update model name
* update linear class
* fix
* fix v0 moe_bias load

Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
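In DeepSeek-style ("DSK") attention, qkv_a_proj_with_mqa fuses the q_a_proj and kv_a_proj_with_mqa down-projections into a single replicated matmul, so the checkpoint's two weight tensors must land in adjacent column ranges of one merged parameter. A minimal layout sketch, with numpy standing in for paddle tensors and illustrative DeepSeek-V3-style sizes (assumed, not taken from this commit):

import numpy as np

# Illustrative sizes only (assumed DeepSeek-V3-style values, not from this diff).
hidden_size = 7168
q_lora_rank = 1536            # width of the q_a shard
kv_a_width = 512 + 64         # width of the kv_a shard (kv_lora_rank + qk_rope_head_dim)
output_sizes = [q_lora_rank, kv_a_width]

# The merged weight holds both projections side by side along the output axis:
# q_a occupies columns [0, 1536), kv_a occupies columns [1536, 2112).
merged_weight = np.zeros((hidden_size, sum(output_sizes)), dtype=np.float32)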
@@ -298,6 +298,76 @@ class ReplicatedLinear(LinearBase):
        )


class MergedReplicatedLinear(ReplicatedLinear):
    """
    MergedReplicatedLinear layer: a ReplicatedLinear whose output is the
    concatenation of several logical sub-layers, loaded shard by shard.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_sizes: list[int] = None,
        with_bias: bool = False,
        add_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
        weight_key: str = "",
    ):
        """
        Initializes a merged replicated linear layer.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_sizes (list[int]): Output widths of the merged sub-layers. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Data type of the weight. Defaults to "".
            weight_key (str): Name of the weight in the checkpoint. Defaults to "".
        """
        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=input_size,
            output_size=sum(output_sizes),
            with_bias=with_bias,
            add_bias=add_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
            weight_key=weight_key,
        )
        self.output_sizes = output_sizes

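    # Illustrative construction of the DeepSeek-style fused projection
    # (attribute and size names assumed for illustration, not part of this commit):
    #
    #     self.qkv_a_proj_with_mqa = MergedReplicatedLinear(
    #         fd_config,
    #         prefix=f"{prefix}.qkv_a_proj_with_mqa",
    #         input_size=hidden_size,
    #         output_sizes=[q_lora_rank, kv_lora_rank + qk_rope_head_dim],
    #     )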

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        model_format = getattr(param, "model_format", "")
        loaded_weight = get_tensor(loaded_weight)

        # Torch checkpoints store linear weights as [out_features, in_features];
        # transpose to paddle's [in_features, out_features] layout.
        if model_format == "torch":
            loaded_weight = loaded_weight.transpose([1, 0])

        assert loaded_shard_id in ["q_a", "kv_a"]
        if not param._is_initialized():
            param.initialize()

        # Each shard owns a contiguous column range of the merged weight:
        # q_a at [0, output_sizes[0]), kv_a immediately after it.
        if loaded_shard_id == "q_a":
            param_shard_offset = 0
            param_shard_size = self.output_sizes[0]
        else:
            # loaded_shard_id == "kv_a"
            param_shard_offset = self.output_sizes[0]
            param_shard_size = self.output_sizes[1]

        # Record the filled slice, then copy the shard into its column range.
        if hasattr(param, "tensor_track"):
            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
        param = slice_fn(param, True, start=param_shard_offset, end=param_shard_offset + param_shard_size)
        assert param.shape == loaded_weight.shape, (
            f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape})"
        )
        param.copy_(loaded_weight, False)


class ColumnParallelLinear(LinearBase):
    """
    ColumnParallelLinear Layer.
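For reference, a self-contained sketch of the two-shard load path, with numpy arrays standing in for paddle parameters; it mirrors the offset arithmetic of weight_loader above (slice_fn, tensor_track, and quantized layouts are omitted):

import numpy as np

output_sizes = [4, 2]                      # toy q_a / kv_a widths
param = np.zeros((3, sum(output_sizes)))   # merged [in_features, out_features] weight

def load_shard(param, loaded_weight, loaded_shard_id):
    # Same branch structure as MergedReplicatedLinear.weight_loader.
    assert loaded_shard_id in ["q_a", "kv_a"]
    if loaded_shard_id == "q_a":
        offset, size = 0, output_sizes[0]
    else:  # loaded_shard_id == "kv_a"
        offset, size = output_sizes[0], output_sizes[1]
    shard = param[:, offset : offset + size]   # a view, so writes land in param
    assert shard.shape == loaded_weight.shape, (
        f"Attempted to load weight ({loaded_weight.shape}) into parameter ({shard.shape})"
    )
    shard[:] = loaded_weight                   # in-place copy, like param.copy_(loaded_weight, False)

load_shard(param, np.ones((3, 4)), "q_a")        # fills columns 0..3
load_shard(param, np.full((3, 2), 2.0), "kv_a")  # fills columns 4..5
print(param)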