mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
[Model] Qwen2.5VL support --use-cudagraph and unit testing (#4087)
* [BugFix] qwen2.5vl enable_thinking=true and image_patch_id bug fix * [Docs]offine infer add apply_chat_template add_generation_prompt parameter * [Model]qwen2.5VL support --use-cudagraph * [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test * [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test * [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v2 * [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v3 * [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v4 * [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v5 * [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v6 * [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v7
This commit is contained in:
@@ -107,7 +107,7 @@ messages = [
|
||||
}
|
||||
]
|
||||
|
||||
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
images, videos = [], []
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
|
@@ -107,7 +107,7 @@ messages = [
|
||||
}
|
||||
]
|
||||
|
||||
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
images, videos = [], []
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
|
@@ -27,6 +27,7 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||
support_graph_optimization,
|
||||
)
|
||||
@@ -39,12 +40,6 @@ from fastdeploy.model_executor.models.model_base import (
|
||||
ModelRegistry,
|
||||
)
|
||||
from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import extract_text_token_output
|
||||
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@support_graph_optimization
|
||||
@@ -108,31 +103,17 @@ class Qwen2_5_VLModel(nn.Layer):
|
||||
logger.info(f"Start load layer {i}")
|
||||
self.layers[i].load_state_dict(state_dict)
|
||||
|
||||
def get_input_embeddings(self, ids_remove_padding: paddle.Tensor) -> paddle.Tensor:
|
||||
return self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_embeddings: paddle.Tensor,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
image_features: Optional[paddle.Tensor],
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
|
||||
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||
|
||||
# -----------------------
|
||||
# 将 image_embeds 替换 input_embeds 里的 image video 占位符
|
||||
image_mask = ids_remove_padding == self.image_token_id
|
||||
image_token_num = image_mask.sum()
|
||||
|
||||
video_mask = ids_remove_padding == self.video_token_id
|
||||
video_token_num = video_mask.sum()
|
||||
|
||||
# 由于框架只有 image_features,所以目前不支持图片和视频混合
|
||||
# TODO(wangyafeng) 后续考虑支持传入 video_features
|
||||
if image_token_num > 0:
|
||||
hidden_states[image_mask] = image_features.cast(self._dtype)
|
||||
if video_token_num > 0:
|
||||
hidden_states[video_mask] = image_features.cast(self._dtype)
|
||||
|
||||
# -----------------------
|
||||
hidden_states = input_embeddings
|
||||
|
||||
residual = None
|
||||
for i in range(self.num_layers):
|
||||
@@ -144,18 +125,6 @@ class Qwen2_5_VLModel(nn.Layer):
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# -----------------------
|
||||
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1)
|
||||
hidden_states = extract_text_token_output(
|
||||
max_seq_len,
|
||||
max_seq_len_index.cast("int32"),
|
||||
image_token_num.cast("int32"),
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
hidden_states.cast("float32"),
|
||||
).cast(self._dtype)
|
||||
# -----------------------
|
||||
|
||||
out = self.norm(hidden_states)
|
||||
|
||||
return out
|
||||
@@ -183,6 +152,12 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
|
||||
# ----------- language model -------------
|
||||
self.model = Qwen2_5_VLModel(fd_config=fd_config)
|
||||
|
||||
# Persistent buffers for CUDA graphs.
|
||||
self._input_embeddings = paddle.zeros(
|
||||
[fd_config.parallel_config.max_model_len, fd_config.model_config.hidden_size],
|
||||
dtype=fd_config.model_config.dtype,
|
||||
)
|
||||
|
||||
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
@@ -246,14 +221,42 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
|
||||
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
|
||||
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
image_features: Optional[paddle.Tensor] = None,
|
||||
) -> paddle.Tensor:
|
||||
|
||||
input_embeddings = self.model.get_input_embeddings(ids_remove_padding=ids_remove_padding)
|
||||
|
||||
image_mask = ids_remove_padding == self.model.image_token_id
|
||||
image_token_num = image_mask.sum()
|
||||
|
||||
video_mask = ids_remove_padding == self.model.video_token_id
|
||||
video_token_num = video_mask.sum()
|
||||
|
||||
# 由于框架只有 image_features,所以目前不支持图片和视频混合
|
||||
# TODO(wangyafeng) 后续考虑支持传入 video_features
|
||||
if image_token_num > 0:
|
||||
input_embeddings[image_mask] = image_features.cast(self.model._dtype)
|
||||
if video_token_num > 0:
|
||||
input_embeddings[video_mask] = image_features.cast(self.model._dtype)
|
||||
|
||||
return input_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
image_features: Optional[paddle.Tensor],
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
input_embeddings = self.get_input_embeddings(
|
||||
ids_remove_padding=ids_remove_padding, image_features=image_features
|
||||
)
|
||||
self._input_embeddings.copy_(input_embeddings, False)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_embeddings=self._input_embeddings,
|
||||
ids_remove_padding=ids_remove_padding,
|
||||
image_features=image_features,
|
||||
forward_meta=forward_meta,
|
||||
|
503
tests/ci_use/Qwen2_5_VL/test_Qwen2_5_VL_serving.py
Normal file
503
tests/ci_use/Qwen2_5_VL/test_Qwen2_5_VL_serving.py
Normal file
@@ -0,0 +1,503 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
# Read ports from environment variables; use default values if not set
|
||||
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
|
||||
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
|
||||
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
|
||||
|
||||
# List of ports to clean before and after tests
|
||||
PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT]
|
||||
|
||||
|
||||
def is_port_open(host: str, port: int, timeout=1.0):
|
||||
"""
|
||||
Check if a TCP port is open on the given host.
|
||||
Returns True if connection succeeds, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with socket.create_connection((host, port), timeout):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def kill_process_on_port(port: int):
|
||||
"""
|
||||
Kill processes that are listening on the given port.
|
||||
Uses `lsof` to find process ids and sends SIGKILL.
|
||||
"""
|
||||
try:
|
||||
output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
|
||||
for pid in output.splitlines():
|
||||
os.kill(int(pid), signal.SIGKILL)
|
||||
print(f"Killed process on port {port}, pid={pid}")
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
def clean_ports():
|
||||
"""
|
||||
Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
|
||||
"""
|
||||
for port in PORTS_TO_CLEAN:
|
||||
kill_process_on_port(port)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_and_run_server():
|
||||
"""
|
||||
Pytest fixture that runs once per test session:
|
||||
- Cleans ports before tests
|
||||
- Starts the API server as a subprocess
|
||||
- Waits for server port to open (up to 30 seconds)
|
||||
- Tears down server after all tests finish
|
||||
"""
|
||||
print("Pre-test port cleanup...")
|
||||
clean_ports()
|
||||
|
||||
model_path = "/ModelData/Qwen2.5-VL-7B-Instruct"
|
||||
|
||||
log_path = "server.log"
|
||||
limit_mm_str = json.dumps({"image": 100, "video": 100})
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"fastdeploy.entrypoints.openai.api_server",
|
||||
"--model",
|
||||
model_path,
|
||||
"--port",
|
||||
str(FD_API_PORT),
|
||||
# "--tensor-parallel-size",
|
||||
# "2",
|
||||
"--engine-worker-queue-port",
|
||||
str(FD_ENGINE_QUEUE_PORT),
|
||||
"--metrics-port",
|
||||
str(FD_METRICS_PORT),
|
||||
"--enable-mm",
|
||||
"--max-model-len",
|
||||
"32768",
|
||||
"--max-num-batched-tokens",
|
||||
"384",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--limit-mm-per-prompt",
|
||||
limit_mm_str,
|
||||
]
|
||||
|
||||
print(cmd)
|
||||
# Start subprocess in new process group
|
||||
with open(log_path, "w") as logfile:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=logfile,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True, # Enables killing full group via os.killpg
|
||||
)
|
||||
|
||||
print(f"Started API server with pid {process.pid}")
|
||||
# Wait up to 10 minutes for API server to be ready
|
||||
for _ in range(10 * 60):
|
||||
if is_port_open("127.0.0.1", FD_API_PORT):
|
||||
print(f"API server is up on port {FD_API_PORT}")
|
||||
break
|
||||
time.sleep(1)
|
||||
else:
|
||||
print("[TIMEOUT] API server failed to start in 10 minutes. Cleaning up...")
|
||||
try:
|
||||
os.killpg(process.pid, signal.SIGTERM)
|
||||
except Exception as e:
|
||||
print(f"Failed to kill process group: {e}")
|
||||
raise RuntimeError(f"API server did not start on port {FD_API_PORT}")
|
||||
|
||||
yield # Run tests
|
||||
|
||||
print("\n===== Post-test server cleanup... =====")
|
||||
try:
|
||||
os.killpg(process.pid, signal.SIGTERM)
|
||||
print(f"API server (pid={process.pid}) terminated")
|
||||
except Exception as e:
|
||||
print(f"Failed to terminate API server: {e}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def api_url(request):
|
||||
"""
|
||||
Returns the API endpoint URL for chat completions.
|
||||
"""
|
||||
return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def metrics_url(request):
|
||||
"""
|
||||
Returns the metrics endpoint URL.
|
||||
"""
|
||||
return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def headers():
|
||||
"""
|
||||
Returns common HTTP request headers.
|
||||
"""
|
||||
return {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def consistent_payload():
|
||||
"""
|
||||
Returns a fixed payload for consistency testing,
|
||||
including a fixed random seed and temperature.
|
||||
"""
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
}
|
||||
],
|
||||
"temperature": 0.8,
|
||||
"top_p": 0, # fix top_p to reduce randomness
|
||||
"seed": 13, # fixed random seed
|
||||
}
|
||||
|
||||
|
||||
# ==========================
|
||||
# Consistency test for repeated runs with fixed payload
|
||||
# ==========================
|
||||
def test_consistency_between_runs(api_url, headers, consistent_payload):
|
||||
"""
|
||||
Test that result is same as the base result.
|
||||
"""
|
||||
# request
|
||||
resp1 = requests.post(api_url, headers=headers, json=consistent_payload)
|
||||
assert resp1.status_code == 200
|
||||
result1 = resp1.json()
|
||||
content1 = result1["choices"][0]["message"]["content"]
|
||||
file_res_temp = "Qwen2.5-VL-7B-Instruct-temp"
|
||||
f_o = open(file_res_temp, "a")
|
||||
f_o.writelines(content1)
|
||||
f_o.close()
|
||||
|
||||
# base result
|
||||
content2 = "这张图片展示了一群人在进行手工艺活动。前景中有两个孩子和一个成年人,他们似乎在制作或展示某种手工艺品。成年人手里拿着一个扇子,上面有彩色的图案,可能是在指导孩子们如何制作这个扇子。孩子们看起来很专注,正在认真地观察和学习。背景中还有其他人在进行类似的活动,环境看起来像是在一个教室或工作室里。整体氛围显得非常温馨和积极。"
|
||||
|
||||
# Verify that result is same as the base result
|
||||
assert content1 == content2
|
||||
|
||||
|
||||
# ==========================
|
||||
# OpenAI Client Chat Completion Test
|
||||
# ==========================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client():
|
||||
ip = "0.0.0.0"
|
||||
service_http_port = str(FD_API_PORT)
|
||||
client = openai.Client(
|
||||
base_url=f"http://{ip}:{service_http_port}/v1",
|
||||
api_key="EMPTY_API_KEY",
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
# Non-streaming test
|
||||
def test_non_streaming_chat(openai_client):
|
||||
"""Test non-streaming chat functionality with the local service"""
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant.",
|
||||
}, # system不是必需,可选
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=53,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert hasattr(response, "choices")
|
||||
assert len(response.choices) > 0
|
||||
assert hasattr(response.choices[0], "message")
|
||||
assert hasattr(response.choices[0].message, "content")
|
||||
|
||||
|
||||
# Streaming test
|
||||
def test_streaming_chat(openai_client, capsys):
|
||||
"""Test streaming chat functionality with the local service"""
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant.",
|
||||
}, # system不是必需,可选
|
||||
{"role": "user", "content": "List 3 countries and their capitals."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "China(Beijing), France(Paris), Australia(Canberra).",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=512,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
output = []
|
||||
for chunk in response:
|
||||
if hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"):
|
||||
output.append(chunk.choices[0].delta.content)
|
||||
assert len(output) > 2
|
||||
|
||||
|
||||
# ==========================
|
||||
# OpenAI Client additional chat/completions test
|
||||
# ==========================
|
||||
|
||||
|
||||
def test_non_streaming_chat_with_return_token_ids(openai_client, capsys):
|
||||
"""
|
||||
Test return_token_ids option in non-streaming chat functionality with the local service
|
||||
"""
|
||||
# 设定 return_token_ids
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant."}, # system不是必需,可选
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=53,
|
||||
extra_body={"return_token_ids": True},
|
||||
stream=False,
|
||||
)
|
||||
assert hasattr(response, "choices")
|
||||
assert len(response.choices) > 0
|
||||
assert hasattr(response.choices[0], "message")
|
||||
assert hasattr(response.choices[0].message, "prompt_token_ids")
|
||||
assert isinstance(response.choices[0].message.prompt_token_ids, list)
|
||||
assert hasattr(response.choices[0].message, "completion_token_ids")
|
||||
assert isinstance(response.choices[0].message.completion_token_ids, list)
|
||||
|
||||
# 不设定 return_token_ids
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant."}, # system不是必需,可选
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=53,
|
||||
extra_body={"return_token_ids": False},
|
||||
stream=False,
|
||||
)
|
||||
assert hasattr(response, "choices")
|
||||
assert len(response.choices) > 0
|
||||
assert hasattr(response.choices[0], "message")
|
||||
assert hasattr(response.choices[0].message, "prompt_token_ids")
|
||||
assert response.choices[0].message.prompt_token_ids is None
|
||||
assert hasattr(response.choices[0].message, "completion_token_ids")
|
||||
assert response.choices[0].message.completion_token_ids is None
|
||||
|
||||
|
||||
def test_streaming_chat_with_return_token_ids(openai_client, capsys):
|
||||
"""
|
||||
Test return_token_ids option in streaming chat functionality with the local service
|
||||
"""
|
||||
# enable return_token_ids
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant."}, # system不是必需,可选
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=53,
|
||||
extra_body={"return_token_ids": True},
|
||||
stream=True,
|
||||
)
|
||||
is_first_chunk = True
|
||||
for chunk in response:
|
||||
assert hasattr(chunk, "choices")
|
||||
assert len(chunk.choices) > 0
|
||||
assert hasattr(chunk.choices[0], "delta")
|
||||
assert hasattr(chunk.choices[0].delta, "prompt_token_ids")
|
||||
assert hasattr(chunk.choices[0].delta, "completion_token_ids")
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
assert isinstance(chunk.choices[0].delta.prompt_token_ids, list)
|
||||
assert chunk.choices[0].delta.completion_token_ids is None
|
||||
else:
|
||||
assert chunk.choices[0].delta.prompt_token_ids is None
|
||||
assert isinstance(chunk.choices[0].delta.completion_token_ids, list)
|
||||
|
||||
# disable return_token_ids
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant."}, # system不是必需,可选
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=53,
|
||||
extra_body={"return_token_ids": False},
|
||||
stream=True,
|
||||
)
|
||||
for chunk in response:
|
||||
assert hasattr(chunk, "choices")
|
||||
assert len(chunk.choices) > 0
|
||||
assert hasattr(chunk.choices[0], "delta")
|
||||
assert hasattr(chunk.choices[0].delta, "prompt_token_ids")
|
||||
assert chunk.choices[0].delta.prompt_token_ids is None
|
||||
assert hasattr(chunk.choices[0].delta, "completion_token_ids")
|
||||
assert chunk.choices[0].delta.completion_token_ids is None
|
||||
|
||||
|
||||
def test_profile_reset_block_num():
|
||||
"""测试profile reset_block_num功能,与baseline diff不能超过15%"""
|
||||
log_file = "./log/config.log"
|
||||
baseline = 30000
|
||||
|
||||
if not os.path.exists(log_file):
|
||||
pytest.fail(f"Log file not found: {log_file}")
|
||||
|
||||
with open(log_file, "r") as f:
|
||||
log_lines = f.readlines()
|
||||
|
||||
target_line = None
|
||||
for line in log_lines:
|
||||
if "Reset block num" in line:
|
||||
target_line = line.strip()
|
||||
break
|
||||
|
||||
if target_line is None:
|
||||
pytest.fail("日志中没有Reset block num信息")
|
||||
|
||||
match = re.search(r"total_block_num:(\d+)", target_line)
|
||||
if not match:
|
||||
pytest.fail(f"Failed to extract total_block_num from line: {target_line}")
|
||||
|
||||
try:
|
||||
actual_value = int(match.group(1))
|
||||
except ValueError:
|
||||
pytest.fail(f"Invalid number format: {match.group(1)}")
|
||||
|
||||
lower_bound = baseline * (1 - 0.15)
|
||||
upper_bound = baseline * (1 + 0.15)
|
||||
print(f"Reset total_block_num: {actual_value}. baseline: {baseline}")
|
||||
|
||||
assert lower_bound <= actual_value <= upper_bound, (
|
||||
f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
|
||||
f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
|
||||
)
|
503
tests/e2e/test_Qwen2_5_VL_serving.py
Normal file
503
tests/e2e/test_Qwen2_5_VL_serving.py
Normal file
@@ -0,0 +1,503 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
# Read ports from environment variables; use default values if not set
|
||||
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
|
||||
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
|
||||
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
|
||||
|
||||
# List of ports to clean before and after tests
|
||||
PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT]
|
||||
|
||||
|
||||
def is_port_open(host: str, port: int, timeout=1.0):
|
||||
"""
|
||||
Check if a TCP port is open on the given host.
|
||||
Returns True if connection succeeds, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with socket.create_connection((host, port), timeout):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def kill_process_on_port(port: int):
|
||||
"""
|
||||
Kill processes that are listening on the given port.
|
||||
Uses `lsof` to find process ids and sends SIGKILL.
|
||||
"""
|
||||
try:
|
||||
output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
|
||||
for pid in output.splitlines():
|
||||
os.kill(int(pid), signal.SIGKILL)
|
||||
print(f"Killed process on port {port}, pid={pid}")
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
def clean_ports():
|
||||
"""
|
||||
Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
|
||||
"""
|
||||
for port in PORTS_TO_CLEAN:
|
||||
kill_process_on_port(port)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_and_run_server():
|
||||
"""
|
||||
Pytest fixture that runs once per test session:
|
||||
- Cleans ports before tests
|
||||
- Starts the API server as a subprocess
|
||||
- Waits for server port to open (up to 30 seconds)
|
||||
- Tears down server after all tests finish
|
||||
"""
|
||||
print("Pre-test port cleanup...")
|
||||
clean_ports()
|
||||
|
||||
model_path = "/ModelData/Qwen2.5-VL-7B-Instruct"
|
||||
|
||||
log_path = "server.log"
|
||||
limit_mm_str = json.dumps({"image": 100, "video": 100})
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"fastdeploy.entrypoints.openai.api_server",
|
||||
"--model",
|
||||
model_path,
|
||||
"--port",
|
||||
str(FD_API_PORT),
|
||||
# "--tensor-parallel-size",
|
||||
# "2",
|
||||
"--engine-worker-queue-port",
|
||||
str(FD_ENGINE_QUEUE_PORT),
|
||||
"--metrics-port",
|
||||
str(FD_METRICS_PORT),
|
||||
"--enable-mm",
|
||||
"--max-model-len",
|
||||
"32768",
|
||||
"--max-num-batched-tokens",
|
||||
"384",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--limit-mm-per-prompt",
|
||||
limit_mm_str,
|
||||
]
|
||||
|
||||
print(cmd)
|
||||
# Start subprocess in new process group
|
||||
with open(log_path, "w") as logfile:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=logfile,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True, # Enables killing full group via os.killpg
|
||||
)
|
||||
|
||||
print(f"Started API server with pid {process.pid}")
|
||||
# Wait up to 10 minutes for API server to be ready
|
||||
for _ in range(10 * 60):
|
||||
if is_port_open("127.0.0.1", FD_API_PORT):
|
||||
print(f"API server is up on port {FD_API_PORT}")
|
||||
break
|
||||
time.sleep(1)
|
||||
else:
|
||||
print("[TIMEOUT] API server failed to start in 10 minutes. Cleaning up...")
|
||||
try:
|
||||
os.killpg(process.pid, signal.SIGTERM)
|
||||
except Exception as e:
|
||||
print(f"Failed to kill process group: {e}")
|
||||
raise RuntimeError(f"API server did not start on port {FD_API_PORT}")
|
||||
|
||||
yield # Run tests
|
||||
|
||||
print("\n===== Post-test server cleanup... =====")
|
||||
try:
|
||||
os.killpg(process.pid, signal.SIGTERM)
|
||||
print(f"API server (pid={process.pid}) terminated")
|
||||
except Exception as e:
|
||||
print(f"Failed to terminate API server: {e}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def api_url(request):
|
||||
"""
|
||||
Returns the API endpoint URL for chat completions.
|
||||
"""
|
||||
return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def metrics_url(request):
|
||||
"""
|
||||
Returns the metrics endpoint URL.
|
||||
"""
|
||||
return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def headers():
|
||||
"""
|
||||
Returns common HTTP request headers.
|
||||
"""
|
||||
return {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def consistent_payload():
|
||||
"""
|
||||
Returns a fixed payload for consistency testing,
|
||||
including a fixed random seed and temperature.
|
||||
"""
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
}
|
||||
],
|
||||
"temperature": 0.8,
|
||||
"top_p": 0, # fix top_p to reduce randomness
|
||||
"seed": 13, # fixed random seed
|
||||
}
|
||||
|
||||
|
||||
# ==========================
|
||||
# Consistency test for repeated runs with fixed payload
|
||||
# ==========================
|
||||
def test_consistency_between_runs(api_url, headers, consistent_payload):
|
||||
"""
|
||||
Test that result is same as the base result.
|
||||
"""
|
||||
# request
|
||||
resp1 = requests.post(api_url, headers=headers, json=consistent_payload)
|
||||
assert resp1.status_code == 200
|
||||
result1 = resp1.json()
|
||||
content1 = result1["choices"][0]["message"]["content"]
|
||||
file_res_temp = "Qwen2.5-VL-7B-Instruct-temp"
|
||||
f_o = open(file_res_temp, "a")
|
||||
f_o.writelines(content1)
|
||||
f_o.close()
|
||||
|
||||
# base result
|
||||
content2 = "这张图片展示了一群人在进行手工艺活动。前景中有两个孩子和一个成年人,他们似乎在制作或展示某种手工艺品。成年人手里拿着一个扇子,上面有彩色的图案,可能是通过某种方式绘制或涂鸦而成。孩子们看起来很专注,可能是在观察或参与这个过程。\n\n背景中还有其他几个人,其中一个人穿着粉色的衣服,背对着镜头。整个场景看起来像是在一个室内环境中,光线充足,氛围轻松愉快。"
|
||||
|
||||
# Verify that result is same as the base result
|
||||
assert content1 == content2
|
||||
|
||||
|
||||
# ==========================
|
||||
# OpenAI Client Chat Completion Test
|
||||
# ==========================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client():
|
||||
ip = "0.0.0.0"
|
||||
service_http_port = str(FD_API_PORT)
|
||||
client = openai.Client(
|
||||
base_url=f"http://{ip}:{service_http_port}/v1",
|
||||
api_key="EMPTY_API_KEY",
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
# Non-streaming test
|
||||
def test_non_streaming_chat(openai_client):
|
||||
"""Test non-streaming chat functionality with the local service"""
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant.",
|
||||
}, # system不是必需,可选
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=53,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert hasattr(response, "choices")
|
||||
assert len(response.choices) > 0
|
||||
assert hasattr(response.choices[0], "message")
|
||||
assert hasattr(response.choices[0].message, "content")
|
||||
|
||||
|
||||
# Streaming test
|
||||
def test_streaming_chat(openai_client, capsys):
|
||||
"""Test streaming chat functionality with the local service"""
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant.",
|
||||
}, # system不是必需,可选
|
||||
{"role": "user", "content": "List 3 countries and their capitals."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "China(Beijing), France(Paris), Australia(Canberra).",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=512,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
output = []
|
||||
for chunk in response:
|
||||
if hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"):
|
||||
output.append(chunk.choices[0].delta.content)
|
||||
assert len(output) > 2
|
||||
|
||||
|
||||
# ==========================
|
||||
# OpenAI Client additional chat/completions test
|
||||
# ==========================
|
||||
|
||||
|
||||
def test_non_streaming_chat_with_return_token_ids(openai_client, capsys):
|
||||
"""
|
||||
Test return_token_ids option in non-streaming chat functionality with the local service
|
||||
"""
|
||||
# 设定 return_token_ids
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant."}, # system不是必需,可选
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=53,
|
||||
extra_body={"return_token_ids": True},
|
||||
stream=False,
|
||||
)
|
||||
assert hasattr(response, "choices")
|
||||
assert len(response.choices) > 0
|
||||
assert hasattr(response.choices[0], "message")
|
||||
assert hasattr(response.choices[0].message, "prompt_token_ids")
|
||||
assert isinstance(response.choices[0].message.prompt_token_ids, list)
|
||||
assert hasattr(response.choices[0].message, "completion_token_ids")
|
||||
assert isinstance(response.choices[0].message.completion_token_ids, list)
|
||||
|
||||
# 不设定 return_token_ids
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant."}, # system不是必需,可选
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=53,
|
||||
extra_body={"return_token_ids": False},
|
||||
stream=False,
|
||||
)
|
||||
assert hasattr(response, "choices")
|
||||
assert len(response.choices) > 0
|
||||
assert hasattr(response.choices[0], "message")
|
||||
assert hasattr(response.choices[0].message, "prompt_token_ids")
|
||||
assert response.choices[0].message.prompt_token_ids is None
|
||||
assert hasattr(response.choices[0].message, "completion_token_ids")
|
||||
assert response.choices[0].message.completion_token_ids is None
|
||||
|
||||
|
||||
def test_streaming_chat_with_return_token_ids(openai_client, capsys):
|
||||
"""
|
||||
Test return_token_ids option in streaming chat functionality with the local service
|
||||
"""
|
||||
# enable return_token_ids
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant."}, # system不是必需,可选
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=53,
|
||||
extra_body={"return_token_ids": True},
|
||||
stream=True,
|
||||
)
|
||||
is_first_chunk = True
|
||||
for chunk in response:
|
||||
assert hasattr(chunk, "choices")
|
||||
assert len(chunk.choices) > 0
|
||||
assert hasattr(chunk.choices[0], "delta")
|
||||
assert hasattr(chunk.choices[0].delta, "prompt_token_ids")
|
||||
assert hasattr(chunk.choices[0].delta, "completion_token_ids")
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
assert isinstance(chunk.choices[0].delta.prompt_token_ids, list)
|
||||
assert chunk.choices[0].delta.completion_token_ids is None
|
||||
else:
|
||||
assert chunk.choices[0].delta.prompt_token_ids is None
|
||||
assert isinstance(chunk.choices[0].delta.completion_token_ids, list)
|
||||
|
||||
# disable return_token_ids
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant."}, # system不是必需,可选
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "请描述图片内容"},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=1,
|
||||
max_tokens=53,
|
||||
extra_body={"return_token_ids": False},
|
||||
stream=True,
|
||||
)
|
||||
for chunk in response:
|
||||
assert hasattr(chunk, "choices")
|
||||
assert len(chunk.choices) > 0
|
||||
assert hasattr(chunk.choices[0], "delta")
|
||||
assert hasattr(chunk.choices[0].delta, "prompt_token_ids")
|
||||
assert chunk.choices[0].delta.prompt_token_ids is None
|
||||
assert hasattr(chunk.choices[0].delta, "completion_token_ids")
|
||||
assert chunk.choices[0].delta.completion_token_ids is None
|
||||
|
||||
|
||||
def test_profile_reset_block_num():
|
||||
"""测试profile reset_block_num功能,与baseline diff不能超过15%"""
|
||||
log_file = "./log/config.log"
|
||||
baseline = 30000
|
||||
|
||||
if not os.path.exists(log_file):
|
||||
pytest.fail(f"Log file not found: {log_file}")
|
||||
|
||||
with open(log_file, "r") as f:
|
||||
log_lines = f.readlines()
|
||||
|
||||
target_line = None
|
||||
for line in log_lines:
|
||||
if "Reset block num" in line:
|
||||
target_line = line.strip()
|
||||
break
|
||||
|
||||
if target_line is None:
|
||||
pytest.fail("日志中没有Reset block num信息")
|
||||
|
||||
match = re.search(r"total_block_num:(\d+)", target_line)
|
||||
if not match:
|
||||
pytest.fail(f"Failed to extract total_block_num from line: {target_line}")
|
||||
|
||||
try:
|
||||
actual_value = int(match.group(1))
|
||||
except ValueError:
|
||||
pytest.fail(f"Invalid number format: {match.group(1)}")
|
||||
|
||||
lower_bound = baseline * (1 - 0.15)
|
||||
upper_bound = baseline * (1 + 0.15)
|
||||
print(f"Reset total_block_num: {actual_value}. baseline: {baseline}")
|
||||
|
||||
assert lower_bound <= actual_value <= upper_bound, (
|
||||
f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
|
||||
f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
|
||||
)
|
Reference in New Issue
Block a user