Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 00:57:33 +08:00)
[V1 Loader]Ernie VL support loader v1 (#3494)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* ernie vl support new loader
* add unittest
* fix test
@@ -191,7 +191,7 @@ class FusedMoE(nn.Layer):
                     loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
                     self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
             else:
-                expert_param = param[expert_id]
+                expert_param = param[expert_id - self.expert_id_offset]
                 loaded_weight = get_tensor(loaded_weight)
                 expert_param.copy_(loaded_weight, False)
         else:
@@ -262,7 +262,7 @@ class FusedMoE(nn.Layer):
         loaded_weight,
         shard_id,
     ):
-        expert_param = param[expert_id]
+        expert_param = param[expert_id - self.expert_id_offset]
         if shard_id == "down":
             self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
         elif shard_id in ["gate", "up"]:
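Why the offset matters: under expert parallelism each rank materializes only its slice of the global expert list, so the checkpoint's global expert_id has to be rebased to a local row before indexing param. A minimal sketch, with rank layout and sizes assumed for illustration (not taken from the diff):

# Hypothetical layout: 64 global experts, 4 ranks, 16 experts per rank.
num_experts, num_ranks = 64, 4
experts_per_rank = num_experts // num_ranks
rank = 2
expert_id_offset = rank * experts_per_rank   # 32 for rank 2
global_expert_id = 35                        # id carried by the checkpoint key
local_row = global_expert_id - expert_id_offset
assert local_row == 3                        # indexes this rank's param[3]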
@@ -279,6 +279,7 @@ class FusedMoE(nn.Layer):
         param_gate_up_proj_name: Optional[str] = None,
         param_down_proj_name: Optional[str] = None,
         ckpt_expert_key_name: str = "experts",
+        experts_offset: int = 0,
     ) -> list[tuple[str, str, int, str]]:
         param_name_maping = []
 
@@ -303,7 +304,7 @@ class FusedMoE(nn.Layer):
                 expert_id,
                 shard_id,
             )
-            for expert_id in range(num_experts)
+            for expert_id in range(experts_offset, experts_offset + num_experts)
             for shard_id, weight_name in param_name_maping
         ]
 
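For intuition, the changed comprehension now walks global expert ids starting at experts_offset. A small standalone sketch; the expert counts and the shard list contents are illustrative, not the real param_name_maping:

experts_offset, num_experts = 64, 2   # e.g. image experts placed after 64 text experts
param_name_maping = [("gate_up", "up_gate_proj"), ("down", "down_proj")]  # illustrative
pairs = [
    (expert_id, shard_id)
    for expert_id in range(experts_offset, experts_offset + num_experts)
    for shard_id, weight_name in param_name_maping
]
print(pairs)  # [(64, 'gate_up'), (64, 'down'), (65, 'gate_up'), (65, 'down')]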
@@ -15,6 +15,7 @@
 """
 
 from functools import partial
+from typing import Optional
 
 import numpy as np
 import paddle
@@ -32,7 +33,8 @@ from paddle.nn.functional.flash_attention import (
 )
 from paddleformers.transformers.model_utils import PretrainedModel
 
-from fastdeploy.model_executor.layers.utils import get_tensor
+from fastdeploy.model_executor.layers.utils import divide, get_tensor
+from fastdeploy.model_executor.models.utils import set_weight_attrs
 
 from .activation import ACT2FN
 from .configuration import DFNRopeVisionTransformerConfig
@@ -153,11 +155,13 @@ class VisionFlashAttention2(nn.Layer):
         nn (_type_): _description_
     """
 
-    def __init__(self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1) -> None:
+    def __init__(
+        self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1, tensor_parallel_rank: int = 0
+    ) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.tensor_parallel_degree = tensor_parallel_degree
+        self.tensor_parallel_rank = tensor_parallel_rank
         if tensor_parallel_degree > 1:
             self.qkv = ColumnParallelLinear(
                 dim,
@@ -175,11 +179,42 @@ class VisionFlashAttention2(nn.Layer):
                 input_is_parallel=True,
                 has_bias=True,
             )
+            set_weight_attrs(self.qkv.weight, {"weight_loader": self.weight_loader})
+            set_weight_attrs(self.qkv.bias, {"weight_loader": self.weight_loader, "load_bias": True})
+            set_weight_attrs(self.qkv.bias, {"output_dim": True})
+            set_weight_attrs(self.proj.weight, {"output_dim": False})
         else:
             self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
             self.proj = nn.Linear(dim, dim)
 
         self.head_dim = dim // num_heads  # must added
+        self.num_heads = num_heads
+        self.hidden_size = dim
+        self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)
+
+    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        load_bias = getattr(param, "load_bias", None)
+        if load_bias:
+            head_dim = self.hidden_size // self.num_heads
+            shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim])
+            shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
+            shard_weight = shard_weight.reshape([-1])
+        else:
+            shard_weight = loaded_weight[...].reshape(
+                [
+                    self.hidden_size,
+                    3,
+                    self.num_heads,
+                    self.head_dim,
+                ]
+            )
+            shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
+            shard_weight = shard_weight.reshape([self.hidden_size, -1])
+        shard_weight = get_tensor(shard_weight)
+        assert param.shape == shard_weight.shape, (
+            f" Attempted to load weight ({shard_weight.shape}) " f"into parameter ({param.shape})"
+        )
+        param.copy_(shard_weight, False)
+
     def forward(
         self,
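The new weight_loader shards the fused qkv tensor by attention head across tensor-parallel ranks. A self-contained numpy sketch of the non-bias branch, with toy sizes assumed:

import numpy as np

hidden_size, num_heads, head_dim = 8, 4, 2   # hidden_size == num_heads * head_dim
tp_degree, tp_rank = 2, 1

# The checkpoint stores qkv fused as [hidden, 3 * hidden]; expose the head
# axis, split the heads across ranks, then flatten back to this rank's shard.
w = np.zeros((hidden_size, 3 * hidden_size), dtype=np.float32)
w4 = w.reshape(hidden_size, 3, num_heads, head_dim)
shard = np.split(w4, tp_degree, axis=-2)[tp_rank]   # this rank's half of the heads
shard = shard.reshape(hidden_size, -1)
assert shard.shape == (hidden_size, 3 * hidden_size // tp_degree)  # (8, 12)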
@@ -211,7 +246,6 @@ class VisionFlashAttention2(nn.Layer):
             .transpose(perm=[1, 0, 2, 3])
         )
         q, k, v = qkv.unbind(axis=0)
-
         q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
         k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
 
@@ -233,7 +267,6 @@ class VisionFlashAttention2(nn.Layer):
             .squeeze(0)
             .reshape([seq_length, -1])
         )
-
         attn_output = attn_output.astype(paddle.float32)
         attn_output = self.proj(attn_output)
         return attn_output
@@ -306,6 +339,9 @@ class VisionMlp(nn.Layer):
                 input_is_parallel=True,
                 has_bias=True,
             )
+            set_weight_attrs(self.fc1.weight, {"output_dim": True})
+            set_weight_attrs(self.fc1.bias, {"output_dim": True})
+            set_weight_attrs(self.fc2.weight, {"output_dim": False})
         else:
             self.fc1 = nn.Linear(dim, hidden_dim)
             self.fc2 = nn.Linear(hidden_dim, dim)
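The output_dim attribute presumably tells the v1 loader which axis of a parameter is sharded (True for the column-parallel fc1, False for the row-parallel fc2); that reading is an assumption inferred from the pattern, not stated in the diff. Sketch:

import numpy as np

tp_degree, tp_rank = 2, 0
w_fc1 = np.ones((4, 16))   # [in, out]; output_dim=True -> split the output axis
fc1_shard = np.split(w_fc1, tp_degree, axis=1)[tp_rank]   # shape (4, 8)
w_fc2 = np.ones((16, 4))   # [in, out]; output_dim=False -> split the input axis
fc2_shard = np.split(w_fc2, tp_degree, axis=0)[tp_rank]   # shape (8, 4)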
@@ -365,6 +401,7 @@ class DFNRopeVisionBlock(nn.Layer):
         self,
         config,
         tensor_parallel_degree: int,
+        tensor_parallel_rank: int,
         attn_implementation: str = "sdpa",
     ) -> None:
         """_summary_
@@ -382,6 +419,7 @@ class DFNRopeVisionBlock(nn.Layer):
             config.embed_dim,
             num_heads=config.num_heads,
             tensor_parallel_degree=tensor_parallel_degree,
+            tensor_parallel_rank=tensor_parallel_rank,
         )
         self.mlp = VisionMlp(
             dim=config.embed_dim,
@@ -407,7 +445,9 @@ class DFNRopeVisionBlock(nn.Layer):
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
         )
+
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+
         return hidden_states
 
 
@@ -478,6 +518,7 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
                 DFNRopeVisionBlock(
                     config.vision_config,
                     config.pretrained_config.tensor_parallel_degree,
+                    config.pretrained_config.tensor_parallel_rank,
                 )
                 for _ in range(config.vision_config.depth)
             ]
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import inspect
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, Optional, Union
@@ -562,6 +563,93 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
     def name(self):
         return "Ernie4_5_VLMoeForConditionalGeneration"
 
+    def gate_correction_bias_loader(self, params_dict, loaded_weight_name, loaded_weight):
+        text_param_name = loaded_weight_name.replace(
+            "moe_statics.e_score_correction_bias", "text_fused_moe.experts.gate_correction_bias"
+        )
+        image_param_name = loaded_weight_name.replace(
+            "moe_statics.e_score_correction_bias", "image_fused_moe.experts.gate_correction_bias"
+        )
+        text_param = params_dict[text_param_name]
+        image_param = params_dict[image_param_name]
+        loaded_weight = get_tensor(loaded_weight)
+        text_param.copy_(loaded_weight[0].unsqueeze(0), False)
+        image_param.copy_(loaded_weight[1].unsqueeze(0), False)
+
+    @paddle.no_grad()
+    def load_weights(self, weights_iterator) -> None:
+        """
+        Load model parameters from a given weights_iterator object.
+
+        Args:
+            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
+        """
+
+        from fastdeploy.model_executor.models.utils import default_weight_loader
+
+        general_params_mapping = [
+            # (param_name, weight_name, expert_id, shard_id)
+            ("embed_tokens.embeddings", "embed_tokens", None, None),
+            ("lm_head.linear", "lm_head", None, None),
+            ("mlp.image_fused_moe.gate.weight", "mlp.gate.weight_1", None, "gate"),
+            ("mlp.text_fused_moe.gate.weight", "mlp.gate.weight", None, "gate"),
+            ("resampler_model", "ernie.resampler_model", None, None),
+        ]
+
+        text_expert_params_mapping = []
+        if getattr(self.fd_config.model_config, "moe_num_experts", None) is not None:
+            text_expert_params_mapping = FusedMoE.make_expert_params_mapping(
+                num_experts=self.fd_config.model_config.moe_num_experts[0],
+                ckpt_down_proj_name="down_proj",
+                ckpt_gate_up_proj_name="up_gate_proj",
+                param_gate_up_proj_name="text_fused_moe.experts.up_gate_proj_",
+                param_down_proj_name="text_fused_moe.experts.down_proj_",
+            )
+            image_expert_params_mapping = FusedMoE.make_expert_params_mapping(
+                num_experts=self.fd_config.model_config.moe_num_experts[1],
+                ckpt_down_proj_name="down_proj",
+                ckpt_gate_up_proj_name="up_gate_proj",
+                param_gate_up_proj_name="image_fused_moe.experts.up_gate_proj_",
+                param_down_proj_name="image_fused_moe.experts.down_proj_",
+                experts_offset=self.fd_config.model_config.moe_num_experts[0],
+            )
+
+        all_param_mapping = general_params_mapping + text_expert_params_mapping + image_expert_params_mapping
+
+        params_dict = dict(self.named_parameters())
+        expert_id = None
+        shard_id = None
+        for loaded_weight_name, loaded_weight in weights_iterator:
+            for param_name, weight_name, exp_id, shard_id in all_param_mapping:
+                if weight_name not in loaded_weight_name:
+                    continue
+                model_param_name = loaded_weight_name.replace(weight_name, param_name)
+                param = params_dict[model_param_name]
+                expert_id = exp_id
+                shard_id = shard_id
+                break
+            else:
+                # text and image gate_correction_bias are fused in the ckpt and need to be loaded separately
+                if "moe_statics.e_score_correction_bias" in loaded_weight_name:
+                    self.gate_correction_bias_loader(params_dict, loaded_weight_name, loaded_weight)
+                    continue
+                if loaded_weight_name not in params_dict.keys():
+                    continue
+                model_param_name = loaded_weight_name
+                param = params_dict[model_param_name]
+
+            # Get weight loader from parameter and set weight
+            weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+            sig = inspect.signature(weight_loader)
+
+            if "expert_id" in sig.parameters:
+                weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
+            else:
+                weight_loader(param, loaded_weight)
+
+        if self.tie_word_embeddings:
+            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+
     @paddle.no_grad()
     def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
         """
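gate_correction_bias_loader splits the fused correction bias: per the copy above, the checkpoint stores one [2, num_experts] tensor, row 0 for the text experts and row 1 for the image experts. A toy numpy illustration, with the expert count assumed:

import numpy as np

fused = np.arange(8, dtype=np.float32).reshape(2, 4)   # [2, num_experts]
text_bias = fused[0][None, :]    # -> text_fused_moe.experts.gate_correction_bias, shape (1, 4)
image_bias = fused[1][None, :]   # -> image_fused_moe.experts.gate_correction_bias, shape (1, 4)
assert text_bias.shape == image_bias.shape == (1, 4)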
@@ -715,7 +803,6 @@ class Ernie4_5_VLPretrainedModel(PretrainedModel):
         """
         get_tensor_parallel_mappings
         """
-        logger.info("erine inference model _get_tensor_parallel_mappings")
         from fastdeploy.model_executor.models.tp_utils import (
             build_expanded_keys,
             has_prefix,
@@ -30,6 +30,7 @@ from fastdeploy.model_executor.models.ernie4_5_vl.dist_utils import (
     reduce_scatter_group,
     scatter_axis,
 )
+from fastdeploy.model_executor.models.utils import set_weight_attrs
 
 
 class ScatterOp(PyLayer):
@@ -201,7 +202,6 @@ class VariableResolutionResamplerModel(nn.Layer):
                 mark_as_sequence_parallel_parameter(self.spatial_linear[idx].bias)
                 _set_var_distributed(self.spatial_linear[idx].weight, split_axis=0)
                 _set_var_distributed(self.spatial_linear[idx].bias, split_axis=0)
-
         if self.use_temporal_conv:
             for idx in [0, 2, 3]:
                 mark_as_sequence_parallel_parameter(self.temporal_linear[idx].weight)
@@ -210,6 +210,7 @@ class VariableResolutionResamplerModel(nn.Layer):
             mark_as_sequence_parallel_parameter(self.mlp.weight)
             mark_as_sequence_parallel_parameter(self.mlp.bias)
             mark_as_sequence_parallel_parameter(self.after_norm.weight)
+        set_weight_attrs(self.spatial_linear[0].weight, {"output_dim": False})
 
     def spatial_conv_reshape(self, x, spatial_conv_size):
         """
tests/model_loader/test_load_ernie_vl.py (new file, 227 lines)
@@ -0,0 +1,227 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import signal
import socket
import subprocess
import sys
import time

import openai
import pytest

# Read ports from environment variables; use default values if not set
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))

# List of ports to clean before and after tests
PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT]


def is_port_open(host: str, port: int, timeout=1.0):
    """
    Check if a TCP port is open on the given host.
    Returns True if connection succeeds, False otherwise.
    """
    try:
        with socket.create_connection((host, port), timeout):
            return True
    except Exception:
        return False


def kill_process_on_port(port: int):
    """
    Kill processes that are listening on the given port.
    Uses `lsof` to find process ids and sends SIGKILL.
    """
    try:
        output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
        for pid in output.splitlines():
            os.kill(int(pid), signal.SIGKILL)
            print(f"Killed process on port {port}, pid={pid}")
    except subprocess.CalledProcessError:
        pass


def clean_ports():
    """
    Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
    """
    for port in PORTS_TO_CLEAN:
        kill_process_on_port(port)


@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
    """
    Pytest fixture that runs once per test session:
    - Cleans ports before tests
    - Starts the API server as a subprocess
    - Waits for server port to open (up to 10 minutes)
    - Tears down server after all tests finish
    """
    print("Pre-test port cleanup...")
    clean_ports()

    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "ernie-4_5-vl-28b-a3b-bf16-paddle")
    else:
        model_path = "./ernie-4_5-vl-28b-a3b-bf16-paddle"

    log_path = "server.log"
    limit_mm_str = json.dumps({"image": 100, "video": 100})

    cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT),
        "--tensor-parallel-size",
        "2",
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT),
        "--metrics-port",
        str(FD_METRICS_PORT),
        "--enable-mm",
        "--max-model-len",
        "32768",
        "--max-num-batched-tokens",
        "384",
        "--max-num-seqs",
        "128",
        "--limit-mm-per-prompt",
        limit_mm_str,
        "--enable-chunked-prefill",
        "--kv-cache-ratio",
        "0.71",
        "--reasoning-parser",
        "ernie-45-vl",
        "--load_choices",
        "default_v1",
    ]

    # Start subprocess in new process group
    with open(log_path, "w") as logfile:
        process = subprocess.Popen(
            cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
        )

    # Wait up to 10 minutes for API server to be ready
    for _ in range(10 * 60):
        if is_port_open("127.0.0.1", FD_API_PORT):
            print(f"API server is up on port {FD_API_PORT}")
            break
        time.sleep(1)
    else:
        print("[TIMEOUT] API server failed to start in 10 minutes. Cleaning up...")
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except Exception as e:
            print(f"Failed to kill process group: {e}")
        raise RuntimeError(f"API server did not start on port {FD_API_PORT}")

    yield  # Run tests

    print("\n===== Post-test server cleanup... =====")
    try:
        os.killpg(process.pid, signal.SIGTERM)
        print(f"API server (pid={process.pid}) terminated")
    except Exception as e:
        print(f"Failed to terminate API server: {e}")


@pytest.fixture(scope="session")
def api_url(request):
    """
    Returns the API endpoint URL for chat completions.
    """
    return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions"


@pytest.fixture(scope="session")
def metrics_url(request):
    """
    Returns the metrics endpoint URL.
    """
    return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"


@pytest.fixture
def headers():
    """
    Returns common HTTP request headers.
    """
    return {"Content-Type": "application/json"}


# ==========================
# OpenAI Client Chat Completion Test
# ==========================


@pytest.fixture
def openai_client():
    ip = "0.0.0.0"
    service_http_port = str(FD_API_PORT)
    client = openai.Client(
        base_url=f"http://{ip}:{service_http_port}/v1",
        api_key="EMPTY_API_KEY",
    )
    return client


# Non-streaming test
def test_non_streaming_chat(openai_client):
    """Test non-streaming chat functionality with the local service"""
    response = openai_client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "system",
                "content": "You are a helpful AI assistant.",
            },  # the system message is optional
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
                            "detail": "high",
                        },
                    },
                    {"type": "text", "text": "请描述图片内容"},  # "Please describe the image content"
                ],
            },
        ],
        temperature=1,
        max_tokens=53,
        stream=False,
    )

    assert hasattr(response, "choices")
    assert len(response.choices) > 0
    assert hasattr(response.choices[0], "message")
    assert hasattr(response.choices[0].message, "content")
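To exercise the new test locally, an invocation along these lines should work (the model directory and free ports are whatever your environment provides; the MODEL_PATH value below is a placeholder):

MODEL_PATH=/path/to/models python -m pytest tests/model_loader/test_load_ernie_vl.py -s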