diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index d0a366e38..fe8910211 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -720,6 +720,7 @@ class KVBatchLinear(LinearBase):
         self.v_head_dim = v_head_dim
         # Split num_attention_heads when using TP inference.
         self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
+        self.local_rank = fd_config.parallel_config.tensor_parallel_rank

         # Initialize parent with combined dimensions
         super().__init__(
@@ -738,6 +739,63 @@ class KVBatchLinear(LinearBase):
         self.k_weight_key = f"{prefix.replace('kv_b_proj', 'k_b_proj')}.weight"
         self.v_weight_key = f"{prefix.replace('kv_b_proj', 'v_b_proj')}.weight"

+        self.k_b_proj_weight = self.create_parameter(
+            shape=[self.num_heads_per_partition, self.qk_nope_head_dim, self.kv_lora_rank],
+            dtype=self.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+
+        self.v_b_proj_weight = self.create_parameter(
+            shape=[self.num_heads_per_partition, self.kv_lora_rank, self.v_head_dim],
+            dtype=self.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+
+        set_weight_attrs(
+            self.k_b_proj_weight,
+            {"weight_loader": self.weight_loader},
+        )
+
+        if self.nranks > 0:
+            _set_var_distributed(self.k_b_proj_weight, split_axis=1)
+            set_weight_attrs(self.k_b_proj_weight, {"output_dim": True})
+
+    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        output_dim = getattr(param, "output_dim", None)
+        # Tensor parallelism splits the weight along the output_dim
+        if output_dim is not None:
+            dim = -1
+            size = loaded_weight.get_shape()[dim]
+            block_size = size // self.nranks
+            shard_offset = self.local_rank * block_size
+            shard_size = (self.local_rank + 1) * block_size
+            loaded_weight = loaded_weight[..., shard_offset:shard_size]
+        w = (
+            get_tensor(loaded_weight)
+            .reshape(
+                [
+                    self.kv_lora_rank,
+                    self.num_heads_per_partition,
+                    -1,
+                ]
+            )
+            .transpose(perm=[1, 2, 0])
+        )
+        if param.dtype != w.dtype:
+            w = w.cast(param.dtype)
+        # Split into K and V weights
+        # wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank]
+        wk_b = w[:, : self.qk_nope_head_dim, :]
+        if self.v_head_dim is None:
+            raise ValueError("self.v_head_dim should not be None")
+        # wv_b: [num_heads, kv_lora_rank, v_head_dim]
+        wv_b = w[:, -self.v_head_dim :, :].transpose(perm=[0, 2, 1])
+
+        self.k_b_proj_weight.set_value(wk_b)
+        self.v_b_proj_weight.set_value(wv_b)
+
     def load_state_dict(self, state_dict: dict):
         """
         Load the combined KV weight and split it into K and V projections
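Note on the `weight_loader` above: it assumes the checkpoint stores `kv_b_proj.weight` as `[kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)]`, shards the output (head) dimension across TP ranks, then splits each head's slab into the K and V projections. A minimal standalone sketch of that slicing, with illustrative sizes rather than the model's real config:

```python
import paddle

# Illustrative sizes only; the real values come from the model config.
kv_lora_rank, num_heads, qk_nope_head_dim, v_head_dim = 512, 16, 128, 128
nranks, local_rank = 2, 0
num_heads_per_partition = num_heads // nranks

# Assumed checkpoint layout of kv_b_proj.weight:
# [kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)]
combined = paddle.randn([kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)])

# TP shard along the last (output) dim, as in weight_loader above.
block_size = combined.shape[-1] // nranks
shard = combined[..., local_rank * block_size : (local_rank + 1) * block_size]

# [kv_lora_rank, H_local, head_dims] -> [H_local, head_dims, kv_lora_rank]
w = shard.reshape([kv_lora_rank, num_heads_per_partition, -1]).transpose(perm=[1, 2, 0])

wk_b = w[:, :qk_nope_head_dim, :]                        # [H_local, qk_nope, lora]
wv_b = w[:, -v_head_dim:, :].transpose(perm=[0, 2, 1])   # [H_local, lora, v_head]
assert wk_b.shape == [num_heads_per_partition, qk_nope_head_dim, kv_lora_rank]
assert wv_b.shape == [num_heads_per_partition, kv_lora_rank, v_head_dim]
```

Because the reshape puts heads on the middle axis, a contiguous slice of the output dimension corresponds to a contiguous block of heads, which is what makes the per-rank slice valid.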
""" scores = paddle.nn.functional.sigmoid(gating_output) - scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) + scores_with_bias = scores + e_score_correction_bias scores, topk_values, topk_idx = noaux_tc( scores, scores_with_bias, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index c46bbac72..16b75e9e2 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -508,10 +508,11 @@ class FusedMoE(nn.Layer): gate_correction_bias_tensor = self.extract_gate_correction_bias( self.gate_correction_bias_key, state_dict ) + if self.gate_correction_bias.shape != gate_correction_bias_tensor.shape: + gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(self.gate_correction_bias.shape) self.gate_correction_bias.set_value(gate_correction_bias_tensor) else: self.gate_correction_bias = None - else: self.gate_correction_bias = None diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 967909645..03f6cea76 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -628,6 +628,79 @@ class DeepseekV3ForCausalLM(ModelForCasualLM): self.model.load_state_dict(state_dict) self.lm_head.load_state_dict(state_dict) + @paddle.no_grad() + def load_weights(self, weights_iterator) -> None: + """ + Load model parameters from a given weights_iterator object. + Args: + weights_iterator (Iterator): An iterator yielding (name, weight) pairs. + """ + from fastdeploy.model_executor.models.utils import default_weight_loader + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("up_gate_proj", "gate_proj", "gate"), + ("up_gate_proj", "up_proj", "up"), + ("embed_tokens.embeddings", "embed_tokens", None), + ("lm_head.linear", "lm_head", None), + ("experts.gate_correction_bias", "gate.e_score_correction_bias", None), + ] + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + param_gate_up_proj_name="experts.up_gate_proj_", + param_down_proj_name="experts.down_proj_", + num_experts=self.fd_config.model_config.n_routed_experts, + ) + params_dict = dict(self.named_parameters()) + + for loaded_weight_name, loaded_weight in weights_iterator: + loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model") + loaded_weight_name = loaded_weight_name.replace("layers", "decoder_layers") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in loaded_weight_name: + continue + if "mlp.experts." 
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 967909645..03f6cea76 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -628,6 +628,79 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
         self.model.load_state_dict(state_dict)
         self.lm_head.load_state_dict(state_dict)

+    @paddle.no_grad()
+    def load_weights(self, weights_iterator) -> None:
+        """
+        Load model parameters from a given weights_iterator object.
+
+        Args:
+            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
+        """
+        from fastdeploy.model_executor.models.utils import default_weight_loader
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("up_gate_proj", "gate_proj", "gate"),
+            ("up_gate_proj", "up_proj", "up"),
+            ("embed_tokens.embeddings", "embed_tokens", None),
+            ("lm_head.linear", "lm_head", None),
+            ("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
+        ]
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            param_gate_up_proj_name="experts.up_gate_proj_",
+            param_down_proj_name="experts.down_proj_",
+            num_experts=self.fd_config.model_config.n_routed_experts,
+        )
+        params_dict = dict(self.named_parameters())
+
+        for loaded_weight_name, loaded_weight in weights_iterator:
+            loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
+            loaded_weight_name = loaded_weight_name.replace("layers", "decoder_layers")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in loaded_weight_name:
+                    continue
+                if "mlp.experts." in loaded_weight_name and loaded_weight_name not in params_dict:
+                    continue
+                model_param_name = loaded_weight_name.replace(weight_name, param_name)
+
+                if model_param_name not in params_dict:
+                    continue
+
+                param = params_dict[model_param_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in loaded_weight_name:
+                        continue
+                    model_param_name = loaded_weight_name.replace(weight_name, param_name)
+                    if model_param_name not in params_dict:
+                        continue
+                    param = params_dict[model_param_name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
+                    break
+                else:
+                    if loaded_weight_name not in params_dict:
+                        continue
+                    param = params_dict[loaded_weight_name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+                    weight_loader(param, loaded_weight)
+                    if "kv_b_proj.weight" in loaded_weight_name:
+                        # handle kv_b_proj_bmm
+                        model_param_name = loaded_weight_name.replace(
+                            "kv_b_proj.weight", "kv_b_proj_bmm.k_b_proj_weight"
+                        )
+                        param = params_dict[model_param_name]
+                        weight_loader = getattr(param, "weight_loader", None)
+                        weight_loader(param, loaded_weight, shard_id)
+
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py
index e2caf21b8..1d2f21a82 100644
--- a/fastdeploy/model_executor/models/utils.py
+++ b/fastdeploy/model_executor/models/utils.py
@@ -78,9 +78,13 @@ def default_weight_loader(fd_config: FDConfig) -> None:
             if param.dtype != loaded_weight.dtype:
                 loaded_weight = loaded_weight.cast(param.dtype)

-            assert param.shape == loaded_weight.shape, (
-                f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-            )
+            if param.shape != loaded_weight.shape:
+                try:
+                    param = param.reshape(loaded_weight.shape)
+                except ValueError as e:
+                    raise ValueError(
+                        f" Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape}). {e}"
+                    )

             param.copy_(loaded_weight, False)
         except Exception:
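Finally, `default_weight_loader` replaces the hard shape assert with a reshape fallback: a mismatch that preserves the element count (for example `[num_experts, 1]` vs. `[num_experts]`, as with the gate-correction bias above) is reconciled, while anything else still raises with both shapes in the message. A standalone sketch mirroring that control flow, with hypothetical stand-in tensors:

```python
import paddle

param = paddle.zeros([8, 1])      # e.g. a bias stored as a column
loaded_weight = paddle.ones([8])  # checkpoint stores it flat

if param.dtype != loaded_weight.dtype:
    loaded_weight = loaded_weight.cast(param.dtype)

if param.shape != loaded_weight.shape:
    try:
        # Succeeds only when the element counts agree ([8, 1] -> [8]).
        param = param.reshape(loaded_weight.shape)
    except ValueError as e:
        # Re-raise with both shapes for easier debugging, as in the patch.
        raise ValueError(
            f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape}). {e}"
        )

param.copy_(loaded_weight, False)
```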