diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index 2c7f9aef3..8ce9fb122 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -298,6 +298,76 @@ class ReplicatedLinear(LinearBase):
         )
 
 
+class MergedReplicatedLinear(ReplicatedLinear):
+    """
+    MergedReplicatedLinear linear layer.
+    """
+
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        prefix: str = "",
+        input_size: int = None,
+        output_sizes: list[int] = None,
+        with_bias: bool = False,
+        add_bias: bool = False,
+        skip_quant: bool = False,
+        weight_dtype: str = "",
+        weight_key: str = "",
+    ):
+        """
+        Initializes a merged replicated linear layer.
+        Args:
+            fd_config (FDConfig): Inference-related parameters.
+            prefix (str): Unique name of the layer, used to name internal attributes.
+                Can be arbitrarily named.
+            input_size (int): Number of input features. Defaults to None.
+            output_sizes (list[int]): Output sizes of the sub-layers merged into this layer. Defaults to None.
+            with_bias (bool): Whether to include bias or not. Defaults to False.
+            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
+            skip_quant (bool): Whether to skip quantization. Defaults to False.
+        """
+        super().__init__(
+            fd_config=fd_config,
+            prefix=prefix,
+            input_size=input_size,
+            output_size=sum(output_sizes),
+            with_bias=with_bias,
+            add_bias=add_bias,
+            skip_quant=skip_quant,
+            weight_dtype=weight_dtype,
+            weight_key=weight_key,
+        )
+        self.output_sizes = output_sizes
+
+    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        model_format = getattr(param, "model_format", "")
+        loaded_weight = get_tensor(loaded_weight)
+
+        if model_format == "torch":
+            loaded_weight = loaded_weight.transpose([1, 0])
+
+        assert loaded_shard_id in ["q_a", "kv_a"]
+        if not param._is_initialized():
+            param.initialize()
+
+        if loaded_shard_id == "q_a":
+            param_shard_offset = 0
+            param_shard_size = self.output_sizes[0]
+        else:
+            # loaded_shard_id == "kv_a"
+            param_shard_offset = self.output_sizes[0]
+            param_shard_size = self.output_sizes[1]
+
+        if hasattr(param, "tensor_track"):
+            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
+        param = slice_fn(param, True, start=param_shard_offset, end=param_shard_offset + param_shard_size)
+        assert param.shape == loaded_weight.shape, (
+            f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape})"
+        )
+        param.copy_(loaded_weight, False)
+
+
 class ColumnParallelLinear(LinearBase):
     """
     ColumnParallelLinear Layer.
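Reviewer note (not part of the diff): the shard bookkeeping in `weight_loader` boils down to an offset computation over `output_sizes`. A minimal runnable sketch of that mapping, with sizes chosen purely for illustration:

```python
# Standalone sketch of the shard-id -> slice mapping that
# MergedReplicatedLinear.weight_loader performs on the merged output dim.
# The sizes below are illustrative assumptions, not values from the PR.
output_sizes = [1536, 576]  # e.g. [q_lora_rank, kv_lora_rank + qk_rope_head_dim]

def shard_bounds(loaded_shard_id: str) -> tuple[int, int]:
    """Return the (start, end) slice a shard occupies in the merged weight."""
    assert loaded_shard_id in ["q_a", "kv_a"]
    if loaded_shard_id == "q_a":
        offset, size = 0, output_sizes[0]
    else:  # "kv_a"
        offset, size = output_sizes[0], output_sizes[1]
    return offset, offset + size

print(shard_bounds("q_a"))   # (0, 1536)
print(shard_bounds("kv_a"))  # (1536, 2112)
```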
diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py
index 89c0a5d88..79c84d701 100644
--- a/fastdeploy/model_executor/layers/quantization/weight_only.py
+++ b/fastdeploy/model_executor/layers/quantization/weight_only.py
@@ -24,6 +24,7 @@ from paddle.nn.quant import weight_only_linear, weight_quantize
 from fastdeploy import envs
 from fastdeploy.model_executor.layers.linear import (
     MergedColumnParallelLinear,
+    MergedReplicatedLinear,
     QKVParallelLinear,
 )
 from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
@@ -203,11 +204,15 @@ class WeightOnlyLinearMethod(QuantMethodBase):
             default_initializer=paddle.nn.initializer.Constant(0),
         )
         quant_attrs = extra_weight_attrs
-        if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
+        if (
+            isinstance(layer, MergedColumnParallelLinear)
+            or isinstance(layer, QKVParallelLinear)
+            or isinstance(layer, MergedReplicatedLinear)
+        ):
             quant_attrs = {
                 **extra_weight_attrs,
                 "tensor_track": TensorTracker(
-                    shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
+                    shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim", True)
                 ),
             }
         set_weight_attrs(
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index c9bef1844..9058cda4a 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -38,6 +38,7 @@ from fastdeploy.model_executor.layers.linear import (
     ColumnParallelLinear,
     KVBatchLinear,
     MergedColumnParallelLinear,
+    MergedReplicatedLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
@@ -169,6 +170,13 @@ class DeepSeekV3MoE(nn.Layer):
 
     def load_state_dict(self, state_dict):
         """ """
+        if self.experts.gate_correction_bias is not None:
+            gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key)
+            if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
+                gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(
+                    self.experts.gate_correction_bias.shape
+                )
+            self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor)
         self.gate.load_state_dict(state_dict)
         self.experts.load_state_dict(state_dict)
         self.shared_experts.load_state_dict(state_dict)
@@ -211,11 +219,11 @@ class DeepseekV3MLAAttention(nn.Layer):
 
         if self.q_lora_rank is not None:
             # NOTE: (changwenbin) qkv_a_proj horizontal fusion
-            self.qkv_a_proj_with_mqa = ReplicatedLinear(
+            self.qkv_a_proj_with_mqa = MergedReplicatedLinear(
                 fd_config=fd_config,
                 prefix=f"{prefix}.qkv_a_proj_with_mqa",
                 input_size=self.hidden_size,
-                output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
+                output_sizes=[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 with_bias=False,
             )
 
@@ -636,6 +644,8 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
             ("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
+            ("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
+            ("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
         ]
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
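Reviewer note (illustrative, not part of the diff): the two new mapping entries route both checkpoint tensors, `q_a_proj` and `kv_a_proj_with_mqa`, into the single fused `qkv_a_proj_with_mqa` parameter, distinguished only by shard id. A hedged sketch of how such (param_name, weight_name, shard_id) triples are typically consumed; the load loop and the `params` dict here are assumptions, not FastDeploy's actual loader:

```python
# Sketch of routing a checkpoint tensor through the new mapping entries.
# `params` (a name -> parameter dict) and this loop are illustrative
# assumptions; only the mapping triples themselves come from the diff.
weight_mappings = [
    ("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
    ("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
]

def route(checkpoint_name, loaded_weight, params):
    """Copy one checkpoint tensor into the fused parameter it belongs to."""
    for param_name, weight_name, shard_id in weight_mappings:
        if weight_name in checkpoint_name:
            target = checkpoint_name.replace(weight_name, param_name)
            param = params[target]
            # MergedReplicatedLinear.weight_loader slices the fused parameter
            # by shard_id before copying, so both tensors land in disjoint slices.
            weight_loader = getattr(param, "weight_loader")
            weight_loader(param, loaded_weight, shard_id)
            return True
    return False
```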
diff --git a/tests/model_loader/test_common_model.py b/tests/model_loader/test_common_model.py
index f2c348195..07c179650 100644
--- a/tests/model_loader/test_common_model.py
+++ b/tests/model_loader/test_common_model.py
@@ -58,6 +58,19 @@ model_param_map = {
             {"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
         ],
     },
+    "DeepSeek-V3-0324": {
+        "tensor_parallel_size": 2,
+        "quantizations": [
+            {
+                "quant_type": "wint4",
+                "env": {
+                    "FD_ATTENTION_BACKEND": "MLA_ATTN",
+                    "FLAGS_mla_use_tensorcore": "1",
+                    "FLAGS_flash_attn_version": "3",
+                },
+            },
+        ],
+    },
 }
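Reviewer note: the new test entry exercises the fused path end to end: DeepSeek-V3-0324 at tensor_parallel_size=2, wint4 weight-only quantization, and the MLA attention path forced on via environment variables. A minimal sketch of exporting that environment before model construction; how the harness actually consumes `model_param_map` is outside this diff:

```python
import os

# Env block copied from the new test case; the comments read the intent
# off the variable names and may not capture every effect.
case_env = {
    "FD_ATTENTION_BACKEND": "MLA_ATTN",  # select the MLA attention backend
    "FLAGS_mla_use_tensorcore": "1",     # enable tensor-core MLA kernels
    "FLAGS_flash_attn_version": "3",     # request FlashAttention-3
}
os.environ.update(case_env)  # must happen before the model is created
```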