【Inference Optimize】Update MergedReplicatedLinear for DSK qkv_a_proj_with_mqa. (#3673)

* support MergedReplicatedLinear

* update MergedReplicatedLinear to support DSK_wint4 V1_load

* update model name

* update linear class

* fix

* fix v0 moe_bias load

---------

Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
AIbin
2025-09-05 12:16:05 +08:00
committed by GitHub
parent b23fc654d9
commit 41aee08982
4 changed files with 102 additions and 4 deletions


@@ -298,6 +298,76 @@ class ReplicatedLinear(LinearBase):
)
class MergedReplicatedLinear(ReplicatedLinear):
    """
    MergedReplicatedLinear layer: a ReplicatedLinear whose weight is the concatenation
    of several sub-projections along the output dimension.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_sizes: list[int] = None,
        with_bias: bool = False,
        add_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
        weight_key: str = "",
    ):
        """
        Initializes a merged replicated linear layer.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_sizes (list[int]): Number of output features of each merged sub-projection. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Weight data type. Defaults to "".
            weight_key (str): Checkpoint key used to locate the weight. Defaults to "".
        """
        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=input_size,
            output_size=sum(output_sizes),
            with_bias=with_bias,
            add_bias=add_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
            weight_key=weight_key,
        )
        self.output_sizes = output_sizes

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        """Copies one sub-projection ("q_a" or "kv_a") into its slice of the merged weight."""
        model_format = getattr(param, "model_format", "")
        loaded_weight = get_tensor(loaded_weight)
        if model_format == "torch":
            # Torch checkpoints store weights as [out_features, in_features]; transpose to Paddle layout.
            loaded_weight = loaded_weight.transpose([1, 0])
        assert loaded_shard_id in ["q_a", "kv_a"]
        if not param._is_initialized():
            param.initialize()

        # Compute the offset and size of this shard within the merged output dimension.
        if loaded_shard_id == "q_a":
            param_shard_offset = 0
            param_shard_size = self.output_sizes[0]
        else:
            # loaded_shard_id == "kv_a"
            param_shard_offset = self.output_sizes[0]
            param_shard_size = self.output_sizes[1]
        if hasattr(param, "tensor_track"):
            # Record which region of the merged parameter has been filled.
            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
        param = slice_fn(param, True, start=param_shard_offset, end=param_shard_offset + param_shard_size)
        assert param.shape == loaded_weight.shape, (
            f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape})"
        )
        param.copy_(loaded_weight, False)
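
For context, a minimal usage sketch (not part of this diff) of how MergedReplicatedLinear could back a DeepSeek-style qkv_a_proj_with_mqa: the q_a and kv_a down-projections share one replicated GEMM, and the checkpoint loader routes each source tensor into its slice via the shard ids asserted in weight_loader above. The dimension and variable names below (hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim, the source weight tensors) are illustrative placeholders, not identifiers taken from this commit.

# Hypothetical construction of the merged projection; all sizes are placeholders.
qkv_a_proj_with_mqa = MergedReplicatedLinear(
    fd_config=fd_config,
    prefix=f"{prefix}.qkv_a_proj_with_mqa",
    input_size=hidden_size,
    output_sizes=[
        q_lora_rank,                      # "q_a" shard: low-rank query down-projection
        kv_lora_rank + qk_rope_head_dim,  # "kv_a" shard: compressed KV plus rotary part (MQA)
    ],
    with_bias=False,
)

# At load time, each source tensor fills its own slice of the merged weight.
qkv_a_proj_with_mqa.weight_loader(qkv_a_proj_with_mqa.weight, q_a_proj_weight, "q_a")
qkv_a_proj_with_mqa.weight_loader(qkv_a_proj_with_mqa.weight, kv_a_proj_weight, "kv_a")
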
class ColumnParallelLinear(LinearBase):
    """
    ColumnParallelLinear Layer.