[Feature] Ernie 4.5 VL supports prompt_token_ids + messages input. (#5148)

* support prompt_token_ids + messages

* fix bug

* refactor code structure

* support cache mm items

* refactor code structure

* delete test cases

* modify unit test

* add unit test

* add unit test

* fix append

* add check for messages
Authored by kxz2002 on 2025-11-25 23:11:44 +08:00, committed by GitHub
parent 66e096d509
commit 2d787590c4
4 changed files with 601 additions and 21 deletions
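
This change lets a chat request carry pre-tokenized input together with the messages it was rendered from, so the processor can trust the caller's token ids while still pulling image/video items out of messages. A minimal sketch of such a request dict (values are illustrative, borrowed from the test fixtures below; not a verbatim payload from this PR):

request_dict = {
    "request_id": "req-0",  # illustrative
    "prompt_token_ids": [101, 1002, 1001, 1001, 1003, 102],  # pre-tokenized prompt containing an image span
    "messages": [  # kept alongside the ids so multimodal items can still be extracted
        {"role": "user", "content": [{"type": "image_url", "image_url": "https://example.com/img.png"}]}
    ],
}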

View File

@@ -671,10 +671,7 @@ class ChatCompletionRequest(BaseModel):
         if request_id is not None:
             req_dict["request_id"] = request_id
-        if "prompt_token_ids" in req_dict:
-            if "messages" in req_dict:
-                del req_dict["messages"]
-        else:
+        if "prompt_token_ids" not in req_dict or not req_dict["prompt_token_ids"]:
             # If disable_chat_template is set, then the first message in messages will be used as the prompt.
             assert (
                 len(req_dict["messages"]) > 0

View File

@@ -219,7 +219,13 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
         bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
         request["bad_words_token_ids"] = bad_words_token_ids
-        if request.get("prompt"):
+        if request.get("prompt_token_ids"):
+            messages = request.get("messages")
+            if messages:
+                self._check_mm_limits(messages)
+            request.setdefault("enable_thinking", True)
+            outputs = self.ernie4_5_processor.prompt_token_ids2outputs(request)
+        elif request.get("prompt"):
             multimodal_data = request.get("multimodal_data")
             if multimodal_data is None:
                 multimodal_data = {}
@@ -256,7 +262,9 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
             self.append_completion_tokens(outputs, request["completion_token_ids"])
         outputs = self.pack_outputs(outputs)
-        request["prompt_token_ids"] = outputs["input_ids"].tolist()
+        request["prompt_token_ids"] = (
+            outputs["input_ids"].tolist() if "prompt_token_ids" not in request else request["prompt_token_ids"]
+        )
         request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
         request["multimodal_inputs"] = outputs

View File

@@ -136,7 +136,9 @@ class DataProcessor:
         self.video_end = self.VID_END
         self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
         self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
+        self.image_end_id = self.tokenizer.convert_tokens_to_ids(self.image_end)
         self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
+        self.video_end_id = self.tokenizer.convert_tokens_to_ids(self.video_end)
         self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
         self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)
@@ -243,14 +245,7 @@ class DataProcessor:
         return outputs

-    def request2ids(
-        self, request: Dict[str, Any], tgts: List[str] = None
-    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
-        """
-        Convert chat messages into model inputs.
-        Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
-        """
+    def extract_mm_items(self, request: Dict[str, Any]):
         messages = parse_chat_messages(request.get("messages"))
         mm_items = []
         for msg in messages:
@@ -273,6 +268,7 @@ class DataProcessor:
         if len(missing_hashes) > 0 and not self.enable_processor_cache:
             raise ValueError("Missing items cannot be retrieved without processor cache.")
+        dealer = None
         if self.enable_processor_cache:
             context = zmq.Context()
             dealer = context.socket(zmq.DEALER)
@@ -295,6 +291,16 @@ class DataProcessor:
video_uuid.append(item["uuid"])
else:
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
return images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items
def request2ids(
self, request: Dict[str, Any], tgts: List[str] = None
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
"""
Convert chat messages into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""
images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat template.")
@@ -329,6 +335,115 @@ class DataProcessor:
         return outputs

+    def prompt_token_ids2outputs(
+        self, request: Dict[str, Any], tgts: List[str] = None
+    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
+        outputs = {
+            "input_ids": [],
+            "token_type_ids": [],
+            "position_ids": [],
+            "images": [],
+            "grid_thw": [],
+            "image_type_ids": [],
+            "labels": [],
+            "cur_position": 0,
+            "video_cnt": 0,
+            "num_input_image_tokens": 0,
+            "num_input_video_tokens": 0,
+            "mm_positions": [],
+            "mm_hashes": [],
+        }
+        prompt_token_ids = request.get("prompt_token_ids", [])
+        prompt_token_ids_len = len(prompt_token_ids)
+        if not request.get("messages"):
+            outputs["input_ids"].extend(prompt_token_ids)
+            outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * prompt_token_ids_len)
+            for i in range(prompt_token_ids_len):
+                outputs["position_ids"].append([i] * 3)
+            outputs["cur_position"] += prompt_token_ids_len
+            return outputs
+        images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
+        st, image_idx, video_idx = 0, 0, 0
+        while st < prompt_token_ids_len:
+            cur_token_id = prompt_token_ids[st]
+            if cur_token_id == self.image_start_id:
+                if image_idx >= len(images):
+                    raise ValueError("prompt token ids has more image placeholder than in messages")
+                # append image_start_id
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+                # process placeholder token ids
+                cur_idx = st
+                while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.image_end_id:
+                    cur_idx += 1
+                if cur_idx >= prompt_token_ids_len:
+                    raise ValueError("image token ids not complete")
+                image = images[image_idx]
+                uuid = image_uuid[image_idx] if image_uuid else None
+                token_len = cur_idx - st
+                if not isinstance(image, tuple):
+                    self._add_image(image, outputs, uuid, token_len)
+                else:
+                    self._add_processed_image(image, outputs, uuid, token_len)
+                image_idx += 1
+                st = cur_idx
+            elif cur_token_id == self.video_start_id:
+                if video_idx >= len(videos):
+                    raise ValueError("prompt token ids has more video placeholder than in messages")
+                # append video_start_id
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+                # process placeholder token ids
+                cur_idx = st
+                while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.video_end_id:
+                    cur_idx += 1
+                if cur_idx >= prompt_token_ids_len:
+                    raise ValueError("video token ids not complete")
+                video = videos[video_idx]
+                uuid = video_uuid[video_idx] if video_uuid else None
+                token_len = cur_idx - st
+                if not isinstance(video, tuple):
+                    if isinstance(video, dict):
+                        frames = self._load_and_process_video(video["video"], video)
+                    else:
+                        frames = self._load_and_process_video(video, {})
+                    self._add_video(frames, outputs, uuid, token_len)
+                else:
+                    self._add_processed_video(video, outputs, uuid, token_len)
+                video_idx += 1
+                st = cur_idx
+            else:
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+        if image_idx != len(images):
+            raise ValueError("number of images does not match")
+        if video_idx != len(videos):
+            raise ValueError("number of videos does not match")
+        if self.enable_processor_cache:
+            missing_idx = set(missing_idx)
+            hashes_to_cache, items_to_cache = [], []
+            for idx in range(len(mm_items)):
+                if idx in missing_idx:
+                    continue
+                meta = {}
+                t, h, w = outputs["grid_thw"][idx][0]
+                meta["thw"] = (t, h, w)
+                hashes_to_cache.append(outputs["mm_hashes"][idx])
+                items_to_cache.append((outputs["images"][idx], meta))
+            self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
+        return outputs
+
     def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
         token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token)
         outputs["input_ids"].append(token_id)
@@ -348,7 +463,7 @@ class DataProcessor:
outputs["position_ids"].append([start + i] * 3)
outputs["cur_position"] += len(tokens)
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
def _add_image(self, img, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
img.height,
img.width,
@@ -356,6 +471,8 @@ class DataProcessor:
             max_pixels=self.image_max_pixels,
         )[1]
         num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
+        if token_len and token_len != num_tokens:
+            raise ValueError("image tokens num not match the size")
         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
@@ -383,9 +500,13 @@ class DataProcessor:
outputs["grid_thw"].append(ret["image_grid_thw"])
outputs["image_type_ids"].append(0)
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
def _add_processed_image(
self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
) -> None:
img, meta = img_cache
num_tokens = img.shape[0] // (self.spatial_conv_size**2)
if token_len and num_tokens != token_len:
raise ValueError("image tokens num not match the size")
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
@@ -401,7 +522,7 @@ class DataProcessor:
outputs["grid_thw"].append(np.array([[1, h, w]]))
outputs["image_type_ids"].append(0)
def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
def _add_video(self, frames, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
frames[0].height,
frames[0].width,
@@ -410,6 +531,8 @@ class DataProcessor:
         )[1]
         num_frames = len(frames)
         num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)
+        if token_len and num_tokens != token_len:
+            raise ValueError("video tokens num not match the size")
         pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
         ret = self.image_preprocessor.preprocess(
@@ -438,9 +561,13 @@ class DataProcessor:
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
def _add_processed_video(
self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
) -> None:
frames, meta = frames_cache
num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
if token_len and num_tokens != token_len:
raise ValueError("video tokens num not match the size")
t, h, w = meta["thw"]
outputs["images"].append(frames)

View File

@@ -1,7 +1,15 @@
 import unittest
 from unittest.mock import MagicMock, patch

+import numpy as np
+
 from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
 from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor
+from fastdeploy.input.ernie4_5_vl_processor.image_preprocessor.image_preprocessor_adaptive import (
+    AdaptiveImageProcessor,
+)
+from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor
+from fastdeploy.input.utils import IDS_TYPE_FLAG


 class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
@@ -77,7 +85,7 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)
self.assertEqual(request_dict["enable_thinking"], True)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
@@ -93,7 +101,7 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)
self.assertEqual(request_dict["enable_thinking"], True)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
@@ -101,7 +109,7 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)
self.assertEqual(request_dict["enable_thinking"], True)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
@@ -111,6 +119,446 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
         self.processor.process_request_dict(request_dict, 100)
         self.assertEqual(request_dict["enable_thinking"], True)

+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "chat_template_kwargs": {"options": {"thinking_mode": "close"}},
+        }
+        self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(request_dict["enable_thinking"], False)
+
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "chat_template_kwargs": {"options": {"thinking_mode": "false"}},
+        }
+        self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(request_dict["enable_thinking"], False)
+
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "chat_template_kwargs": {"enable_thinking": False},
+        }
+        self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(request_dict["enable_thinking"], False)
+
+
+class TestDataProcessorTargetMethods(unittest.TestCase):
+    def setUp(self):
+        self.mock_tokenizer = MagicMock(spec=Ernie4_5Tokenizer)
+        self.mock_tokenizer.ignored_index = -100
+        self.mock_tokenizer.convert_tokens_to_ids.side_effect = self._mock_convert_tokens_to_ids
+        self.mock_tokenizer.chat_template = "mock_template"
+        self.mock_tokenizer.apply_chat_template.return_value = "User: Hello<|image@placeholder|>"
+
+        def mock_load_tokenizer(dp_instance):
+            dp_instance.tokenizer = self.mock_tokenizer
+
+        with patch.object(DataProcessor, "_load_tokenizer", side_effect=mock_load_tokenizer, autospec=True):
+            with patch.object(AdaptiveImageProcessor, "from_pretrained") as mock_image_preprocessor:
+                mock_image_preprocessor.return_value = MagicMock()
+                self.data_processor = DataProcessor(
+                    tokenizer_name="mock_tokenizer",
+                    image_preprocessor_name="mock_image_preprocessor",
+                    enable_processor_cache=False,
+                )
+        self.data_processor.image_patch_id = 1001
+        self.data_processor.image_start_id = 1002
+        self.data_processor.image_end_id = 1003
+        self.data_processor.video_start_id = 1004
+        self.data_processor.video_end_id = 1005
+        self.data_processor.role_prefixes = {"user": "User: ", "assistant": "Assistant: "}
+        self.data_processor.enable_processor_cache = False
+        self.data_processor.extract_mm_items = MagicMock(return_value=([], [], [], [], None, [], []))
+
+    def _mock_convert_tokens_to_ids(self, token):
+        token_id_map = {
+            "<|begin_of_sentence|>": 101,
+            "<|end_of_sentence|>": 102,
+            "</s>": 103,
+            "<|IMAGE_PLACEHOLDER|>": 1001,
+            "<|IMAGE_START|>": 1002,
+            "<|IMAGE_END|>": 1003,
+            "<|VIDEO_START|>": 1004,
+            "<|VIDEO_END|>": 1005,
+        }
+        return token_id_map.get(token, 999)
+
+    def test_prompt_token_ids2outputs_only_prompt_token_ids(self):
+        test_prompt_token_ids = [101, 999, 998, 997, 102]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+        }
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+        prompt_len = len(test_prompt_token_ids)
+        self.assertEqual(
+            outputs["input_ids"],
+            test_prompt_token_ids,
+            f"input_ids mismatch: got {outputs['input_ids']}, expected {test_prompt_token_ids}",
+        )
+        self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)
+        expected_position_ids = [[i] * 3 for i in range(prompt_len)]
+        self.assertEqual(outputs["position_ids"], expected_position_ids)
+        self.assertEqual(outputs["cur_position"], prompt_len)
+        self.assertEqual(len(outputs["images"]), 0)
+        self.assertEqual(len(outputs["grid_thw"]), 0)
+        self.assertEqual(len(outputs["mm_positions"]), 0)
+        self.assertEqual(len(outputs["mm_hashes"]), 0)
+        self.assertEqual(outputs["video_cnt"], 0)
+        self.assertEqual(outputs["num_input_image_tokens"], 0)
+        self.assertEqual(outputs["num_input_video_tokens"], 0)
+
+    def test_prompt_token_ids2outputs_with_messages_no_mm(self):
+        test_prompt_token_ids = [101, 999, 998, 997, 102]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [{"role": "user", "content": "Hello World"}],
+        }
+        self.data_processor.extract_mm_items.return_value = ([], [], [], [], None, [], [])
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+        prompt_len = len(test_prompt_token_ids)
+        self.assertEqual(outputs["input_ids"], test_prompt_token_ids)
+        self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)
+        expected_position_ids = [[i] * 3 for i in range(prompt_len)]
+        self.assertEqual(outputs["position_ids"], expected_position_ids)
+        self.assertEqual(outputs["cur_position"], prompt_len)
+        self.assertEqual(len(outputs["images"]), 0)
+        self.assertEqual(outputs["video_cnt"], 0)
+        self.assertEqual(outputs["num_input_image_tokens"], 0)
+
+    def test_prompt_token_ids2outputs_add_image(self):
+        test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
+        mock_img = MagicMock()
+        mock_img.height = 224
+        mock_img.width = 224
+        mock_img.convert.return_value = mock_img
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "image_url", "image_url": mock_img, "uuid": "img_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [mock_img],
+            [],
+            ["img_uuid"],
+            [],
+            None,
+            [],
+            [{"type": "image", "data": mock_img}],
+        )
+        mock_resize = (None, (2, 4))
+        self.data_processor.image_preprocessor.get_smarted_resize.return_value = mock_resize
+        mock_preprocess = {"pixel_values": np.random.randn(1, 16, 16, 3), "image_grid_thw": np.array([[2, 4]])}
+        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
+        # self.data_processor._compute_3d_positions = MagicMock(return_value=[[i]*3 for i in range(4)])
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+        self.assertEqual(outputs["input_ids"], [101, 1002, 1001, 1001, 1003, 102])
+        self.assertEqual(
+            outputs["token_type_ids"],
+            [
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["image"],
+                IDS_TYPE_FLAG["image"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+            ],
+        )
+        self.assertEqual(len(outputs["position_ids"]), 6)
+        self.assertEqual(outputs["cur_position"], 6)
+        self.assertEqual(len(outputs["images"]), 1)
+        self.assertIsNotNone(outputs["images"][0])
+        self.assertEqual(outputs["num_input_image_tokens"], 2)
+        self.assertEqual(len(outputs["mm_positions"]), 1)
+        self.assertEqual(len(outputs["mm_hashes"]), 1)
+        self.assertEqual(len(outputs["grid_thw"]), 1)
+        self.assertEqual(len(outputs["image_type_ids"]), 1)
+
+    def test_prompt_token_ids2outputs_add_processed_image(self):
+        test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
+        mock_img_data = np.random.randn(8, 28, 28)
+        mock_img_cache = (mock_img_data, {"thw": (1, 8, 8)})
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "image_url", "image_url": mock_img_cache, "uuid": "img_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [mock_img_cache],
+            [],
+            ["img_uuid"],
+            [],
+            None,
+            [],
+            [{"type": "image", "data": mock_img_cache}],
+        )
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+        self.assertEqual(outputs["input_ids"], [101, 1002, 1001, 1001, 1003, 102])
+        self.assertEqual(
+            outputs["token_type_ids"],
+            [
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["image"],
+                IDS_TYPE_FLAG["image"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+            ],
+        )
+        self.assertEqual(len(outputs["position_ids"]), 20)
+        self.assertEqual(outputs["cur_position"], 8)
+        self.assertEqual(len(outputs["images"]), 1)
+        self.assertIsNotNone(outputs["images"][0])
+        self.assertEqual(len(outputs["mm_positions"]), 1)
+        self.assertEqual(outputs["mm_hashes"][0], "img_uuid")
+        self.assertEqual(len(outputs["grid_thw"]), 1)
+        self.assertEqual(len(outputs["image_type_ids"]), 1)
+
+    def test_prompt_token_ids2outputs_add_video(self):
+        test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]
+        mock_frame1 = MagicMock()
+        mock_frame1.height = 224
+        mock_frame1.width = 224
+        mock_frame1.convert.return_value = mock_frame1
+        mock_frame2 = MagicMock()
+        mock_frame2.height = 224
+        mock_frame2.width = 224
+        mock_frame2.convert.return_value = mock_frame2
+        frames = [mock_frame1, mock_frame2]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [],
+            [frames],
+            [],
+            ["vid_uuid"],
+            None,
+            [],
+            [{"type": "video", "data": frames}],
+        )
+        self.data_processor._load_and_process_video = MagicMock(return_value=frames)
+        patches_h, patches_w = 4, 4
+        self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
+        mock_preprocess = {
+            "pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3),
+            "video_grid_thw": np.array([[patches_h, patches_w]] * 2),
+        }
+        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+        self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102])
+        self.assertEqual(
+            outputs["token_type_ids"],
+            [
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+            ],
+        )
+        self.assertEqual(len(outputs["position_ids"]), 8)
+        self.assertEqual(outputs["cur_position"], 6)
+        self.assertEqual(len(outputs["images"]), 1)
+        self.assertIsNotNone(outputs["images"][0])
+        self.assertEqual(len(outputs["mm_positions"]), 1)
+        self.assertEqual(outputs["mm_hashes"][0], "vid_uuid")
+        self.assertEqual(len(outputs["grid_thw"]), 1)
+        self.assertEqual(len(outputs["image_type_ids"]), 2)
+        self.assertEqual(outputs["num_input_video_tokens"], 4)
+
+    def test_prompt_token_ids2outputs_add_processed_video(self):
+        test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]
+        t, h, w = 2, 4, 4
+        spatial_conv_size = self.data_processor.spatial_conv_size
+        temporal_conv_size = self.data_processor.temporal_conv_size
+        token_per_frame = (h // spatial_conv_size) * (w // spatial_conv_size)
+        num_tokens = (t // temporal_conv_size) * token_per_frame
+        mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28)
+        mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)})
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [],
+            [mock_frames_cache],
+            [],
+            ["vid_uuid"],
+            None,
+            [],
+            [{"type": "video", "data": mock_frames_cache}],
+        )
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+        self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102])
+        self.assertEqual(
+            outputs["token_type_ids"],
+            [
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+            ],
+        )
+        self.assertEqual(len(outputs["position_ids"]), 8)
+        self.assertEqual(outputs["cur_position"], 6)
+        self.assertEqual(len(outputs["images"]), 1)
+        self.assertIsNotNone(outputs["images"][0])
+        self.assertEqual(len(outputs["mm_positions"]), 1)
+        self.assertEqual(outputs["mm_hashes"][0], "vid_uuid")
+        self.assertEqual(len(outputs["grid_thw"]), 1)
+        self.assertEqual(len(outputs["image_type_ids"]), 2)
+
+    def test_prompt_token_ids2outputs_add_image_token_len_mismatch(self):
+        test_prompt_token_ids = [101, 1002, 1001, 1001, 1001, 1003, 102]
+        mock_img = MagicMock()
+        mock_img.height = 224
+        mock_img.width = 224
+        mock_img.convert.return_value = mock_img
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "image_url", "image_url": mock_img, "uuid": "img_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [mock_img],
+            [],
+            ["img_uuid"],
+            [],
+            None,
+            [],
+            [{"type": "image", "data": mock_img}],
+        )
+        patches_h, patches_w = 8, 8
+        self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
+        mock_preprocess = {
+            "pixel_values": np.random.randn(1, patches_h, patches_w, 3),
+            "image_grid_thw": np.array([[patches_h, patches_w]]),
+        }
+        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
+        with self.assertRaises(ValueError) as ctx:
+            self.data_processor.prompt_token_ids2outputs(request)
+        self.assertIn("image tokens num not match the size", str(ctx.exception))
+
+    def test_prompt_token_ids2outputs_add_processed_image_token_len_mismatch(self):
+        test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
+        spatial_conv_size = self.data_processor.spatial_conv_size
+        num_tokens = 4
+        mock_img_data = np.random.randn(num_tokens * (spatial_conv_size**2), 28, 28)
+        mock_img_cache = (mock_img_data, {"thw": (1, 8, 8)})
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "image_url", "image_url": mock_img_cache, "uuid": "img_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [mock_img_cache],
+            [],
+            ["img_uuid"],
+            [],
+            None,
+            [],
+            [{"type": "image", "data": mock_img_cache}],
+        )
+        with self.assertRaises(ValueError) as ctx:
+            self.data_processor.prompt_token_ids2outputs(request)
+        self.assertIn("image tokens num not match the size", str(ctx.exception))
+
+    def test_prompt_token_ids2outputs_add_video_token_len_mismatch(self):
+        test_prompt_token_ids = [101, 1004, 1001, 1001, 1005, 102]
+        mock_frame1 = MagicMock()
+        mock_frame1.height = 224
+        mock_frame1.width = 224
+        mock_frame1.convert.return_value = mock_frame1
+        mock_frame2 = MagicMock()
+        mock_frame2.height = 224
+        mock_frame2.width = 224
+        mock_frame2.convert.return_value = mock_frame2
+        frames = [mock_frame1, mock_frame2]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [],
+            [frames],
+            [],
+            ["vid_uuid"],
+            None,
+            [],
+            [{"type": "video", "data": frames}],
+        )
+        self.data_processor._load_and_process_video = MagicMock(return_value=frames)
+        patches_h, patches_w = 8, 8
+        self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
+        mock_preprocess = {
+            "pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3),
+            "video_grid_thw": np.array([[patches_h, patches_w]] * 2),
+        }
+        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
+        with self.assertRaises(ValueError) as ctx:
+            self.data_processor.prompt_token_ids2outputs(request)
+        self.assertIn("video tokens num not match the size", str(ctx.exception))
+
+    def test_prompt_token_ids2outputs_add_processed_video_token_len_mismatch(self):
+        test_prompt_token_ids = [101, 1004, 1001, 1005, 102]
+        t, h, w = 2, 8, 8
+        spatial_conv_size = self.data_processor.spatial_conv_size
+        temporal_conv_size = self.data_processor.temporal_conv_size
+        num_tokens = 4
+        mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28)
+        mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)})
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [],
+            [mock_frames_cache],
+            [],
+            ["vid_uuid"],
+            None,
+            [],
+            [{"type": "video", "data": mock_frames_cache}],
+        )
+        with self.assertRaises(ValueError) as ctx:
+            self.data_processor.prompt_token_ids2outputs(request)
+        self.assertIn("video tokens num not match the size", str(ctx.exception))
+

 if __name__ == "__main__":
     unittest.main()