From 85fbf5455a6844d80cc4e6d3314248982e8ea06c Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 22 Aug 2025 11:16:57 +0800 Subject: [PATCH] [V1 Loader]Ernie VL support loader v1 (#3494) * ernie vl support new loader * add unittest * fix test --- fastdeploy/model_executor/layers/moe/moe.py | 7 +- .../models/ernie4_5_vl/dfnrope/modeling.py | 51 +++- .../models/ernie4_5_vl/ernie4_5_vl_moe.py | 89 ++++++- .../models/ernie4_5_vl/modeling_resampler.py | 3 +- tests/model_loader/test_load_ernie_vl.py | 227 ++++++++++++++++++ 5 files changed, 367 insertions(+), 10 deletions(-) create mode 100644 tests/model_loader/test_load_ernie_vl.py diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 646578f35..28b9afdbe 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -191,7 +191,7 @@ class FusedMoE(nn.Layer): loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size] self.weight_loader(param, loaded_weight_shard, expert_id, shard_id) else: - expert_param = param[expert_id] + expert_param = param[expert_id - self.expert_id_offset] loaded_weight = get_tensor(loaded_weight) expert_param.copy_(loaded_weight, False) else: @@ -262,7 +262,7 @@ class FusedMoE(nn.Layer): loaded_weight, shard_id, ): - expert_param = param[expert_id] + expert_param = param[expert_id - self.expert_id_offset] if shard_id == "down": self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id) elif shard_id in ["gate", "up"]: @@ -279,6 +279,7 @@ class FusedMoE(nn.Layer): param_gate_up_proj_name: Optional[str] = None, param_down_proj_name: Optional[str] = None, ckpt_expert_key_name: str = "experts", + experts_offset: int = 0, ) -> list[tuple[str, str, int, str]]: param_name_maping = [] @@ -303,7 +304,7 @@ class FusedMoE(nn.Layer): expert_id, shard_id, ) - for expert_id in range(num_experts) + for expert_id in range(experts_offset, experts_offset + num_experts) for shard_id, weight_name in param_name_maping ] diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py index 2dcf07559..fcfd80ec3 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py @@ -15,6 +15,7 @@ """ from functools import partial +from typing import Optional import numpy as np import paddle @@ -32,7 +33,8 @@ from paddle.nn.functional.flash_attention import ( ) from paddleformers.transformers.model_utils import PretrainedModel -from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.layers.utils import divide, get_tensor +from fastdeploy.model_executor.models.utils import set_weight_attrs from .activation import ACT2FN from .configuration import DFNRopeVisionTransformerConfig @@ -153,11 +155,13 @@ class VisionFlashAttention2(nn.Layer): nn (_type_): _description_ """ - def __init__(self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1) -> None: + def __init__( + self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1, tensor_parallel_rank: int = 0 + ) -> None: super().__init__() self.num_heads = num_heads self.tensor_parallel_degree = tensor_parallel_degree - + self.tensor_parallel_rank = tensor_parallel_rank if tensor_parallel_degree > 1: self.qkv = ColumnParallelLinear( dim, @@ -175,11 +179,42 @@ class VisionFlashAttention2(nn.Layer): input_is_parallel=True, has_bias=True, ) + 
set_weight_attrs(self.qkv.weight, {"weight_loader": self.weight_loader}) + set_weight_attrs(self.qkv.bias, {"weight_loader": self.weight_loader, "load_bias": True}) + set_weight_attrs(self.qkv.bias, {"output_dim": True}) + set_weight_attrs(self.proj.weight, {"output_dim": False}) else: self.qkv = nn.Linear(dim, dim * 3, bias_attr=True) self.proj = nn.Linear(dim, dim) self.head_dim = dim // num_heads # must added + self.num_heads = num_heads + self.hidden_size = dim + self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree) + + def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): + load_bias = getattr(param, "load_bias", None) + if load_bias: + head_dim = self.hidden_size // self.num_heads + shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim]) + shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank] + shard_weight = shard_weight.reshape([-1]) + else: + shard_weight = loaded_weight[...].reshape( + [ + self.hidden_size, + 3, + self.num_heads, + self.head_dim, + ] + ) + shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank] + shard_weight = shard_weight.reshape([self.hidden_size, -1]) + shard_weight = get_tensor(shard_weight) + assert param.shape == shard_weight.shape, ( + f" Attempted to load weight ({shard_weight.shape}) " f"into parameter ({param.shape})" + ) + param.copy_(shard_weight, False) def forward( self, @@ -211,7 +246,6 @@ class VisionFlashAttention2(nn.Layer): .transpose(perm=[1, 0, 2, 3]) ) q, k, v = qkv.unbind(axis=0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0) @@ -233,7 +267,6 @@ class VisionFlashAttention2(nn.Layer): .squeeze(0) .reshape([seq_length, -1]) ) - attn_output = attn_output.astype(paddle.float32) attn_output = self.proj(attn_output) return attn_output @@ -306,6 +339,9 @@ class VisionMlp(nn.Layer): input_is_parallel=True, has_bias=True, ) + set_weight_attrs(self.fc1.weight, {"output_dim": True}) + set_weight_attrs(self.fc1.bias, {"output_dim": True}) + set_weight_attrs(self.fc2.weight, {"output_dim": False}) else: self.fc1 = nn.Linear(dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, dim) @@ -365,6 +401,7 @@ class DFNRopeVisionBlock(nn.Layer): self, config, tensor_parallel_degree: int, + tensor_parallel_rank: int, attn_implementation: str = "sdpa", ) -> None: """_summary_ @@ -382,6 +419,7 @@ class DFNRopeVisionBlock(nn.Layer): config.embed_dim, num_heads=config.num_heads, tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_rank=tensor_parallel_rank, ) self.mlp = VisionMlp( dim=config.embed_dim, @@ -407,7 +445,9 @@ class DFNRopeVisionBlock(nn.Layer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states @@ -478,6 +518,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel): DFNRopeVisionBlock( config.vision_config, config.pretrained_config.tensor_parallel_degree, + config.pretrained_config.tensor_parallel_rank, ) for _ in range(config.vision_config.depth) ] diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 3789320fd..92146b19a 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ 
b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import inspect
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, Optional, Union
@@ -562,6 +563,94 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
     def name(self):
         return "Ernie4_5_VLMoeForConditionalGeneration"
 
+    def gate_correction_bias_loader(self, params_dict, loaded_weight_name, loaded_weight):
+        text_param_name = loaded_weight_name.replace(
+            "moe_statics.e_score_correction_bias", "text_fused_moe.experts.gate_correction_bias"
+        )
+        image_param_name = loaded_weight_name.replace(
+            "moe_statics.e_score_correction_bias", "image_fused_moe.experts.gate_correction_bias"
+        )
+        text_param = params_dict[text_param_name]
+        image_param = params_dict[image_param_name]
+        loaded_weight = get_tensor(loaded_weight)
+        text_param.copy_(loaded_weight[0].unsqueeze(0), False)
+        image_param.copy_(loaded_weight[1].unsqueeze(0), False)
+
+    @paddle.no_grad()
+    def load_weights(self, weights_iterator) -> None:
+        """
+        Load model parameters from a given weights_iterator object.
+
+        Args:
+            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
+        """
+
+        from fastdeploy.model_executor.models.utils import default_weight_loader
+
+        general_params_mapping = [
+            # (param_name, weight_name, expert_id, shard_id)
+            ("embed_tokens.embeddings", "embed_tokens", None, None),
+            ("lm_head.linear", "lm_head", None, None),
+            ("mlp.image_fused_moe.gate.weight", "mlp.gate.weight_1", None, "gate"),
+            ("mlp.text_fused_moe.gate.weight", "mlp.gate.weight", None, "gate"),
+            ("resampler_model", "ernie.resampler_model", None, None),
+        ]
+
+        text_expert_params_mapping = []
+        image_expert_params_mapping = []
+        if getattr(self.fd_config.model_config, "moe_num_experts", None) is not None:
+            text_expert_params_mapping = FusedMoE.make_expert_params_mapping(
+                num_experts=self.fd_config.model_config.moe_num_experts[0],
+                ckpt_down_proj_name="down_proj",
+                ckpt_gate_up_proj_name="up_gate_proj",
+                param_gate_up_proj_name="text_fused_moe.experts.up_gate_proj_",
+                param_down_proj_name="text_fused_moe.experts.down_proj_",
+            )
+            image_expert_params_mapping = FusedMoE.make_expert_params_mapping(
+                num_experts=self.fd_config.model_config.moe_num_experts[1],
+                ckpt_down_proj_name="down_proj",
+                ckpt_gate_up_proj_name="up_gate_proj",
+                param_gate_up_proj_name="image_fused_moe.experts.up_gate_proj_",
+                param_down_proj_name="image_fused_moe.experts.down_proj_",
+                experts_offset=self.fd_config.model_config.moe_num_experts[0],
+            )
+
+        all_param_mapping = general_params_mapping + text_expert_params_mapping + image_expert_params_mapping
+
+        params_dict = dict(self.named_parameters())
+        expert_id = None
+        shard_id = None
+        for loaded_weight_name, loaded_weight in weights_iterator:
+            for param_name, weight_name, exp_id, sh_id in all_param_mapping:
+                if weight_name not in loaded_weight_name:
+                    continue
+                model_param_name = loaded_weight_name.replace(weight_name, param_name)
+                param = params_dict[model_param_name]
+                expert_id = exp_id
+                shard_id = sh_id
+                break
+            else:
+                # text and image gate_correction_bias are fused in the checkpoint and must be loaded separately
+                if "moe_statics.e_score_correction_bias" in loaded_weight_name:
+                    self.gate_correction_bias_loader(params_dict, loaded_weight_name, loaded_weight)
+                    continue
+                if loaded_weight_name not in params_dict.keys():
+                    continue
+                model_param_name = loaded_weight_name
+                param = params_dict[model_param_name]
+
+            # Get weight loader from parameter and set
weight + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + sig = inspect.signature(weight_loader) + + if "expert_id" in sig.parameters: + weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) + else: + weight_loader(param, loaded_weight) + + if self.tie_word_embeddings: + self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) + @paddle.no_grad() def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]): """ @@ -715,7 +803,6 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel): """ get_tensor_parallel_mappings """ - logger.info("erine inference model _get_tensor_parallel_mappings") from fastdeploy.model_executor.models.tp_utils import ( build_expanded_keys, has_prefix, diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py index b032747d4..80e664e49 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py @@ -30,6 +30,7 @@ from fastdeploy.model_executor.models.ernie4_5_vl.dist_utils import ( reduce_scatter_group, scatter_axis, ) +from fastdeploy.model_executor.models.utils import set_weight_attrs class ScatterOp(PyLayer): @@ -201,7 +202,6 @@ class VariableResolutionResamplerModel(nn.Layer): mark_as_sequence_parallel_parameter(self.spatial_linear[idx].bias) _set_var_distributed(self.spatial_linear[idx].weight, split_axis=0) _set_var_distributed(self.spatial_linear[idx].bias, split_axis=0) - if self.use_temporal_conv: for idx in [0, 2, 3]: mark_as_sequence_parallel_parameter(self.temporal_linear[idx].weight) @@ -210,6 +210,7 @@ class VariableResolutionResamplerModel(nn.Layer): mark_as_sequence_parallel_parameter(self.mlp.weight) mark_as_sequence_parallel_parameter(self.mlp.bias) mark_as_sequence_parallel_parameter(self.after_norm.weight) + set_weight_attrs(self.spatial_linear[0].weight, {"output_dim": False}) def spatial_conv_reshape(self, x, spatial_conv_size): """ diff --git a/tests/model_loader/test_load_ernie_vl.py b/tests/model_loader/test_load_ernie_vl.py new file mode 100644 index 000000000..81c00af68 --- /dev/null +++ b/tests/model_loader/test_load_ernie_vl.py @@ -0,0 +1,227 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import json
+import os
+import signal
+import socket
+import subprocess
+import sys
+import time
+
+import openai
+import pytest
+
+# Read ports from environment variables; use default values if not set
+FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
+FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
+FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
+
+# List of ports to clean before and after tests
+PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT]
+
+
+def is_port_open(host: str, port: int, timeout=1.0):
+    """
+    Check if a TCP port is open on the given host.
+    Returns True if connection succeeds, False otherwise.
+    """
+    try:
+        with socket.create_connection((host, port), timeout):
+            return True
+    except Exception:
+        return False
+
+
+def kill_process_on_port(port: int):
+    """
+    Kill processes that are listening on the given port.
+    Uses `lsof` to find process ids and sends SIGKILL.
+    """
+    try:
+        output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
+        for pid in output.splitlines():
+            os.kill(int(pid), signal.SIGKILL)
+            print(f"Killed process on port {port}, pid={pid}")
+    except subprocess.CalledProcessError:
+        pass
+
+
+def clean_ports():
+    """
+    Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
+    """
+    for port in PORTS_TO_CLEAN:
+        kill_process_on_port(port)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def setup_and_run_server():
+    """
+    Pytest fixture that runs once per test session:
+    - Cleans ports before tests
+    - Starts the API server as a subprocess
+    - Waits for server port to open (up to 10 minutes)
+    - Tears down server after all tests finish
+    """
+    print("Pre-test port cleanup...")
+    clean_ports()
+
+    base_path = os.getenv("MODEL_PATH")
+    if base_path:
+        model_path = os.path.join(base_path, "ernie-4_5-vl-28b-a3b-bf16-paddle")
+    else:
+        model_path = "./ernie-4_5-vl-28b-a3b-bf16-paddle"
+
+    log_path = "server.log"
+    limit_mm_str = json.dumps({"image": 100, "video": 100})
+
+    cmd = [
+        sys.executable,
+        "-m",
+        "fastdeploy.entrypoints.openai.api_server",
+        "--model",
+        model_path,
+        "--port",
+        str(FD_API_PORT),
+        "--tensor-parallel-size",
+        "2",
+        "--engine-worker-queue-port",
+        str(FD_ENGINE_QUEUE_PORT),
+        "--metrics-port",
+        str(FD_METRICS_PORT),
+        "--enable-mm",
+        "--max-model-len",
+        "32768",
+        "--max-num-batched-tokens",
+        "384",
+        "--max-num-seqs",
+        "128",
+        "--limit-mm-per-prompt",
+        limit_mm_str,
+        "--enable-chunked-prefill",
+        "--kv-cache-ratio",
+        "0.71",
+        "--reasoning-parser",
+        "ernie-45-vl",
+        "--load_choices",
+        "default_v1",
+    ]
+
+    # Start subprocess in new process group
+    with open(log_path, "w") as logfile:
+        process = subprocess.Popen(
+            cmd,
+            stdout=logfile,
+            stderr=subprocess.STDOUT,
+            start_new_session=True,  # Enables killing full group via os.killpg
+        )
+
+    # Wait up to 10 minutes for API server to be ready
+    for _ in range(10 * 60):
+        if is_port_open("127.0.0.1", FD_API_PORT):
+            print(f"API server is up on port {FD_API_PORT}")
+            break
+        time.sleep(1)
+    else:
+        print("[TIMEOUT] API server failed to start in 10 minutes. Cleaning up...")
+        try:
+            os.killpg(process.pid, signal.SIGTERM)
+        except Exception as e:
+            print(f"Failed to kill process group: {e}")
+        raise RuntimeError(f"API server did not start on port {FD_API_PORT}")
+
+    yield  # Run tests
+
+    print("\n===== Post-test server cleanup... =====")
+    try:
+        os.killpg(process.pid, signal.SIGTERM)
+        print(f"API server (pid={process.pid}) terminated")
+    except Exception as e:
+        print(f"Failed to terminate API server: {e}")
+
+
+@pytest.fixture(scope="session")
+def api_url(request):
+    """
+    Returns the API endpoint URL for chat completions.
+    """
+    return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions"
+
+
+@pytest.fixture(scope="session")
+def metrics_url(request):
+    """
+    Returns the metrics endpoint URL.
+    """
+    return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
+
+
+@pytest.fixture
+def headers():
+    """
+    Returns common HTTP request headers.
+    """
+    return {"Content-Type": "application/json"}
+
+
+# ==========================
+# OpenAI Client Chat Completion Test
+# ==========================
+
+
+@pytest.fixture
+def openai_client():
+    ip = "0.0.0.0"
+    service_http_port = str(FD_API_PORT)
+    client = openai.Client(
+        base_url=f"http://{ip}:{service_http_port}/v1",
+        api_key="EMPTY_API_KEY",
+    )
+    return client
+
+
+# Non-streaming test
+def test_non_streaming_chat(openai_client):
+    """Test non-streaming chat functionality with the local service"""
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[
+            {
+                "role": "system",
+                "content": "You are a helpful AI assistant.",
+            },  # the system message is optional
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
+                            "detail": "high",
+                        },
+                    },
+                    {"type": "text", "text": "Please describe the content of the image"},
+                ],
+            },
+        ],
+        temperature=1,
+        max_tokens=53,
+        stream=False,
+    )
+
+    assert hasattr(response, "choices")
+    assert len(response.choices) > 0
+    assert hasattr(response.choices[0], "message")
+    assert hasattr(response.choices[0].message, "content")