From 35f85baf099bc704f68f4ffa20aadbf0a1d7a759 Mon Sep 17 00:00:00 2001
From: chen <103103266+ckl117@users.noreply.github.com>
Date: Thu, 27 Nov 2025 20:12:56 +0800
Subject: [PATCH] [BugFix]fix v1 loader lm head fp32 (#5270)

---
 fastdeploy/model_executor/layers/mtp_linear.py            | 4 ++--
 fastdeploy/model_executor/load_weight_utils.py            | 4 +++-
 fastdeploy/model_executor/models/ernie4_5_moe.py          | 4 +++-
 .../model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py  | 4 +++-
 fastdeploy/model_executor/models/qwen2.py                 | 4 +++-
 fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py | 8 ++++++--
 fastdeploy/model_executor/models/qwen3.py                 | 4 +++-
 tests/entrypoints/test_generation.py                      | 1 +
 8 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/fastdeploy/model_executor/layers/mtp_linear.py b/fastdeploy/model_executor/layers/mtp_linear.py
index 42493a1d3..b1699720b 100644
--- a/fastdeploy/model_executor/layers/mtp_linear.py
+++ b/fastdeploy/model_executor/layers/mtp_linear.py
@@ -120,10 +120,10 @@ class ParallelEHProjection(nn.Layer):
         weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
         if self.linear.weight.shape != weight_tensor.shape:
             weight_tensor = weight_tensor.transpose([1, 0])
-        self.linear.weight.set_value(weight_tensor)
+        self.linear.weight.set_value(weight_tensor.astype(self.linear.weight.dtype))
 
         if self.bias_key is not None:
-            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.bias.dtype)
             self.linear.bias.set_value(bias)
 
     def forward(self, input):
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index 6bf85c068..ff485cbbe 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -68,7 +68,9 @@ def load_weights_from_cache(model, weights_iterator):
             )
             param.copy_(loaded_weight, False)
         if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
-            model.lm_head.linear.weight.set_value(loaded_weight.transpose([1, 0]))
+            model.lm_head.linear.weight.set_value(
+                loaded_weight.transpose([1, 0]).astype(model.lm_head.linear.weight.dtype)
+            )
     for _, model_sublayer in model.named_sublayers():
         if isinstance(model_sublayer, KVBatchLinear):
             model_sublayer.process_weights_after_loading()
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index b2ec3fbc9..5140d9632 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -600,7 +600,9 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
             process_weights_after_loading_fn(model_sublayer_name, param)
 
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.linear.weight.set_value(
+                self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
+            )
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
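Note on the hunks above: with lm_head_fp32 enabled, the lm_head parameter lives in float32 while checkpoint tensors arrive in the serving dtype (typically bfloat16), and paddle's set_value expects the incoming tensor to match the parameter's dtype. A minimal standalone sketch of the mismatch and the fix (illustrative Paddle only, not FastDeploy code):

    import paddle

    linear = paddle.nn.Linear(8, 16)                   # parameter created as float32
    loaded = paddle.randn([8, 16]).astype("bfloat16")  # checkpoint tensor in bf16

    # Before the fix: dtypes disagree, so set_value can fail or mis-set the param.
    # linear.weight.set_value(loaded)

    # After the fix: cast to the destination parameter's dtype first.
    linear.weight.set_value(loaded.astype(linear.weight.dtype))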
diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
index 2e357579b..7c1ff3eaf 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -720,7 +720,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
             )
             process_weights_after_loading_fn(model_sublayer_name, param)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.linear.weight.set_value(
+                self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
+            )
 
     @paddle.no_grad()
     def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py
index 0a84248b9..3b3baee62 100644
--- a/fastdeploy/model_executor/models/qwen2.py
+++ b/fastdeploy/model_executor/models/qwen2.py
@@ -376,7 +376,9 @@ class Qwen2ForCausalLM(ModelForCasualLM):
             model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
             process_weights_after_loading_fn(model_sublayer_name, param)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.qwen2.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.linear.weight.set_value(
+                self.qwen2.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
+            )
 
     @classmethod
     def name(self):
diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
index 0f17ec08f..4e751ca9e 100644
--- a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
+++ b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
@@ -232,7 +232,9 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
             process_weights_after_loading_fn(model_sublayer_name, param)
 
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.linear.weight.set_value(
+                self.model.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
+            )
 
     @paddle.no_grad()
     def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
@@ -247,7 +249,9 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
         self.model.load_state_dict(state_dict)
         self.visual.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.linear.weight.set_value(
+                self.model.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
+            )
         else:
             self.lm_head.load_state_dict(state_dict)
 
diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py
index 3fb20da95..67bccc358 100644
--- a/fastdeploy/model_executor/models/qwen3.py
+++ b/fastdeploy/model_executor/models/qwen3.py
@@ -319,7 +319,9 @@ class Qwen3ForCausalLM(ModelForCasualLM):
             process_weights_after_loading_fn(model_sublayer_name, param)
 
         if self.tie_word_embeddings and not is_pooling_model:
-            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.linear.weight.set_value(
+                self.model.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
+            )
 
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
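The tie-word-embeddings hunks above all repeat one pattern: the embedding table is stored as [vocab_size, hidden_size], the lm_head projection expects [hidden_size, vocab_size], so the weight is transposed and then cast to the lm_head's (possibly float32) dtype. A sketch of a shared helper that captures the pattern (hypothetical name, not part of this patch):

    def tie_lm_head_to_embeddings(lm_head, embed_tokens):
        # Embedding weight is [vocab_size, hidden]; lm_head.linear wants
        # [hidden, vocab_size]: transpose, then match the parameter dtype
        # so lm_head_fp32 keeps working.
        weight = embed_tokens.embeddings.weight.transpose([1, 0])
        lm_head.linear.weight.set_value(weight.astype(lm_head.linear.weight.dtype))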
diff --git a/tests/entrypoints/test_generation.py b/tests/entrypoints/test_generation.py
index 617a635ef..1e238cd35 100644
--- a/tests/entrypoints/test_generation.py
+++ b/tests/entrypoints/test_generation.py
@@ -50,6 +50,7 @@ class TestGeneration(unittest.TestCase):
             model=MODEL_NAME,
             max_num_batched_tokens=4096,
             tensor_parallel_size=1,
+            lm_head_fp32=True,
             engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT")),
             cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT")),
         )
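The test change turns the fix into a regression check: constructing the engine with lm_head_fp32=True forces every loader path above to reconcile a float32 lm_head with non-float32 checkpoint weights. The invariant the cast guarantees can be checked in isolation (a hypothetical sanity check, not in the patch):

    import paddle

    param = paddle.create_parameter([4, 4], dtype="float32")
    ckpt = paddle.ones([4, 4], dtype="bfloat16")
    param.set_value(ckpt.astype(param.dtype))  # loader-style cast before set_value
    assert param.dtype == paddle.float32       # parameter keeps its own dtype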