From 4e392e83377bd0c08f8c82f15effc1e19df6a00f Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Fri, 28 Nov 2025 17:52:25 +0800 Subject: [PATCH] [BugFix]fix v1 loader lm head fp32 (#5270) (#5287) --- fastdeploy/model_executor/layers/mtp_linear.py | 4 ++-- fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py | 4 +++- tests/entrypoints/test_generation.py | 1 + 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/fastdeploy/model_executor/layers/mtp_linear.py b/fastdeploy/model_executor/layers/mtp_linear.py index 4250b611f..d6d6c9817 100644 --- a/fastdeploy/model_executor/layers/mtp_linear.py +++ b/fastdeploy/model_executor/layers/mtp_linear.py @@ -120,10 +120,10 @@ class ParallelEHProjection(nn.Layer): weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()) if self.linear.weight.shape != weight_tensor.shape: weight_tensor = weight_tensor.transpose([1, 0]) - self.linear.weight.set_value(weight_tensor) + self.linear.weight.set_value(weight_tensor.astype(self.linear.weight.dtype)) if self.bias_key is not None: - bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()) + bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.bias.dtype) self.linear.bias.set_value(bias) def forward(self, input): diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py index 74345d9d1..82249533d 100644 --- a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py +++ b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py @@ -251,7 +251,9 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM): self.model.load_state_dict(state_dict) self.visual.load_state_dict(state_dict) if self.tie_word_embeddings: - self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) + self.lm_head.linear.weight.set_value( + self.model.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype) + ) else: self.lm_head.load_state_dict(state_dict) diff --git a/tests/entrypoints/test_generation.py b/tests/entrypoints/test_generation.py index 617a635ef..1e238cd35 100644 --- a/tests/entrypoints/test_generation.py +++ b/tests/entrypoints/test_generation.py @@ -50,6 +50,7 @@ class TestGeneration(unittest.TestCase): model=MODEL_NAME, max_num_batched_tokens=4096, tensor_parallel_size=1, + lm_head_fp32=True, engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT")), cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT")), )