Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Feature] ERNIE-4.5-VL supports prompt_token_ids + messages input. (#5148)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* support prompt_token_ids + messages
* fix bug
* refactor code structure
* support cache mm items
* refactor code structure
* delete test cases
* modify unit test
* add unit test
* add unit test
* fix append
* add check for messages
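In practice, the feature means a chat request may now carry pre-tokenized prompt_token_ids together with the original messages, so the VL processor can recover image/video items without re-running the chat template. A minimal sketch of such a request (the ids and uuid are made up for illustration; in the unit tests below, 101/102 stand for BOS/EOS, 1002/1003 for <|IMAGE_START|>/<|IMAGE_END|>, and 1001 for <|IMAGE_PLACEHOLDER|>):

request = {
    "request_id": "demo-req-1",  # hypothetical id
    # pre-tokenized prompt: BOS, IMAGE_START, two patch placeholders, IMAGE_END, EOS
    "prompt_token_ids": [101, 1002, 1001, 1001, 1003, 102],
    # messages are kept alongside the ids so multimodal items can still be extracted
    "messages": [
        {"role": "user", "content": [{"type": "image_url", "image_url": "https://example.com/a.png", "uuid": "img-0"}]}
    ],
}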
@@ -671,10 +671,7 @@ class ChatCompletionRequest(BaseModel):
         if request_id is not None:
             req_dict["request_id"] = request_id

-        if "prompt_token_ids" in req_dict:
-            if "messages" in req_dict:
-                del req_dict["messages"]
-        else:
+        if "prompt_token_ids" not in req_dict or not req_dict["prompt_token_ids"]:
             # If disable_chat_template is set, then the first message in messages will be used as the prompt.
             assert (
                 len(req_dict["messages"]) > 0
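The behavioral change in this hunk: previously, once prompt_token_ids were present, messages were deleted during request normalization; now both fields survive, which is what enables the multimodal handling further down. A hedged paraphrase:

req = {"prompt_token_ids": [101, 102], "messages": [{"role": "user", "content": "hi"}]}
# old behavior: req["messages"] was deleted once token ids were present
# new behavior: req is left intact; the non-empty-messages assert only applies
#               when no usable prompt_token_ids were supplied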
@@ -219,7 +219,13 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids

-        if request.get("prompt"):
+        if request.get("prompt_token_ids"):
+            messages = request.get("messages")
+            if messages:
+                self._check_mm_limits(messages)
+            request.setdefault("enable_thinking", True)
+            outputs = self.ernie4_5_processor.prompt_token_ids2outputs(request)
+        elif request.get("prompt"):
             multimodal_data = request.get("multimodal_data")
             if multimodal_data is None:
                 multimodal_data = {}
@@ -256,7 +262,9 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
                 self.append_completion_tokens(outputs, request["completion_token_ids"])

             outputs = self.pack_outputs(outputs)
-            request["prompt_token_ids"] = outputs["input_ids"].tolist()
+            request["prompt_token_ids"] = (
+                outputs["input_ids"].tolist() if "prompt_token_ids" not in request else request["prompt_token_ids"]
+            )
             request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
             request["multimodal_inputs"] = outputs
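Net effect: the processor only derives prompt_token_ids from the freshly packed outputs when the request did not already carry them, so client-supplied ids win. An equivalent, slightly more explicit form:

if "prompt_token_ids" not in request:
    request["prompt_token_ids"] = outputs["input_ids"].tolist()
request["prompt_token_ids_len"] = len(request["prompt_token_ids"])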
@@ -136,7 +136,9 @@ class DataProcessor:
         self.video_end = self.VID_END
         self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
+        self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
+        self.image_end_id = self.tokenizer.convert_tokens_to_ids(self.image_end)
+        self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
+        self.video_end_id = self.tokenizer.convert_tokens_to_ids(self.video_end)
         self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
         self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)
@@ -243,14 +245,7 @@ class DataProcessor:

         return outputs

-    def request2ids(
-        self, request: Dict[str, Any], tgts: List[str] = None
-    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
-        """
-        Convert chat messages into model inputs.
-        Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
-        """
+    def extract_mm_items(self, request: Dict[str, Any]):
         messages = parse_chat_messages(request.get("messages"))
         mm_items = []
         for msg in messages:
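extract_mm_items is the front half of the old request2ids, factored out so the new pre-tokenized path can reuse it. Its return contract as used throughout the rest of the diff (the annotations are my reading of the code, not documented guarantees):

images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
# images / videos: raw items, or (ndarray, meta) tuples when served from the processor cache
# image_uuid / video_uuid: per-item uuids, later recorded as mm_hashes
# dealer: zmq DEALER socket when the processor cache is enabled, otherwise None
# missing_idx: indices of items whose payload had to be fetched from the cache
# mm_items: the original multimodal entries parsed from messages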
@@ -273,6 +268,7 @@ class DataProcessor:
         if len(missing_hashes) > 0 and not self.enable_processor_cache:
             raise ValueError("Missing items cannot be retrieved without processor cache.")

+        dealer = None
         if self.enable_processor_cache:
             context = zmq.Context()
             dealer = context.socket(zmq.DEALER)
@@ -295,6 +291,16 @@ class DataProcessor:
                 video_uuid.append(item["uuid"])
             else:
                 raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
+        return images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items
+
+    def request2ids(
+        self, request: Dict[str, Any], tgts: List[str] = None
+    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
+        """
+        Convert chat messages into model inputs.
+        Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
+        """
+        images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)

         if self.tokenizer.chat_template is None:
             raise ValueError("This model does not support chat template.")
@@ -329,6 +335,115 @@ class DataProcessor:

         return outputs

+    def prompt_token_ids2outputs(
+        self, request: Dict[str, Any], tgts: List[str] = None
+    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
+        outputs = {
+            "input_ids": [],
+            "token_type_ids": [],
+            "position_ids": [],
+            "images": [],
+            "grid_thw": [],
+            "image_type_ids": [],
+            "labels": [],
+            "cur_position": 0,
+            "video_cnt": 0,
+            "num_input_image_tokens": 0,
+            "num_input_video_tokens": 0,
+            "mm_positions": [],
+            "mm_hashes": [],
+        }
+        prompt_token_ids = request.get("prompt_token_ids", [])
+        prompt_token_ids_len = len(prompt_token_ids)
+        if not request.get("messages"):
+            outputs["input_ids"].extend(prompt_token_ids)
+            outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * prompt_token_ids_len)
+            for i in range(prompt_token_ids_len):
+                outputs["position_ids"].append([i] * 3)
+            outputs["cur_position"] += prompt_token_ids_len
+            return outputs
+        images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
+        st, image_idx, video_idx = 0, 0, 0
+        while st < prompt_token_ids_len:
+            cur_token_id = prompt_token_ids[st]
+            if cur_token_id == self.image_start_id:
+                if image_idx >= len(images):
+                    raise ValueError("prompt token ids has more image placeholder than in messages")
+                # append image_start_id
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+                # process placeholder token ids
+                cur_idx = st
+                while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.image_end_id:
+                    cur_idx += 1
+                if cur_idx >= prompt_token_ids_len:
+                    raise ValueError("image token ids not complete")
+                image = images[image_idx]
+                uuid = image_uuid[image_idx] if image_uuid else None
+                token_len = cur_idx - st
+                if not isinstance(image, tuple):
+                    self._add_image(image, outputs, uuid, token_len)
+                else:
+                    self._add_processed_image(image, outputs, uuid, token_len)
+                image_idx += 1
+                st = cur_idx
+            elif cur_token_id == self.video_start_id:
+                if video_idx >= len(videos):
+                    raise ValueError("prompt token ids has more video placeholder than in messages")
+                # append video_start_id
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+                # process placeholder token ids
+                cur_idx = st
+                while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.video_end_id:
+                    cur_idx += 1
+                if cur_idx >= prompt_token_ids_len:
+                    raise ValueError("video token ids not complete")
+                video = videos[video_idx]
+                uuid = video_uuid[video_idx] if video_uuid else None
+                token_len = cur_idx - st
+                if not isinstance(video, tuple):
+                    if isinstance(video, dict):
+                        frames = self._load_and_process_video(video["video"], video)
+                    else:
+                        frames = self._load_and_process_video(video, {})
+                    self._add_video(frames, outputs, uuid, token_len)
+                else:
+                    self._add_processed_video(video, outputs, uuid, token_len)
+                video_idx += 1
+                st = cur_idx
+            else:
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+        if image_idx != len(images):
+            raise ValueError("number of images does not match")
+        if video_idx != len(videos):
+            raise ValueError("number of videos does not match")
+
+        if self.enable_processor_cache:
+            missing_idx = set(missing_idx)
+            hashes_to_cache, items_to_cache = [], []
+            for idx in range(len(mm_items)):
+                if idx in missing_idx:
+                    continue
+                meta = {}
+                t, h, w = outputs["grid_thw"][idx][0]
+                meta["thw"] = (t, h, w)
+                hashes_to_cache.append(outputs["mm_hashes"][idx])
+                items_to_cache.append((outputs["images"][idx], meta))
+            self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
+
+        return outputs
+
     def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
         token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token)
         outputs["input_ids"].append(token_id)
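A short trace of the scan, using the ids the unit tests below assume (101/102 as BOS/EOS, 1002/1003 as image start/end, 1001 as the patch placeholder); this mirrors test_prompt_token_ids2outputs_add_image:

prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
# st=0: 101  -> plain text token
# st=1: 1002 -> image start; emitted as text, then scan forward to the matching 1003;
#        token_len = 2 placeholders, so _add_image must expand the image into exactly
#        2 patch tokens or raise "image tokens num not match the size"
# st=4: 1003 -> falls through to the else branch as a plain text token
# st=5: 102  -> plain text token
# resulting token_type_ids: [text, text, image, image, text, text]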
@@ -348,7 +463,7 @@ class DataProcessor:
             outputs["position_ids"].append([start + i] * 3)
         outputs["cur_position"] += len(tokens)

-    def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
+    def _add_image(self, img, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
         patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
             img.height,
             img.width,
@@ -356,6 +471,8 @@ class DataProcessor:
             max_pixels=self.image_max_pixels,
         )[1]
         num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
+        if token_len and token_len != num_tokens:
+            raise ValueError("image tokens num not match the size")

         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
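The quantity being validated: spatial patches are merged spatial_conv_size x spatial_conv_size into one token. With the numbers the tests assume (spatial_conv_size = 2 and a 2 x 4 patch grid from get_smarted_resize):

patches_h, patches_w = 2, 4
spatial_conv_size = 2  # assumed default, matching the tests
num_tokens = (patches_h * patches_w) // spatial_conv_size**2  # 8 // 4 = 2
# token_len, the count of <|IMAGE_PLACEHOLDER|> ids between the start/end markers,
# must equal num_tokens or the ValueError above fires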
@@ -383,9 +500,13 @@ class DataProcessor:
         outputs["grid_thw"].append(ret["image_grid_thw"])
         outputs["image_type_ids"].append(0)

-    def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
+    def _add_processed_image(
+        self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
+    ) -> None:
         img, meta = img_cache
         num_tokens = img.shape[0] // (self.spatial_conv_size**2)
+        if token_len and num_tokens != token_len:
+            raise ValueError("image tokens num not match the size")

         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
@@ -401,7 +522,7 @@ class DataProcessor:
         outputs["grid_thw"].append(np.array([[1, h, w]]))
         outputs["image_type_ids"].append(0)

-    def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
+    def _add_video(self, frames, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
         patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
             frames[0].height,
             frames[0].width,
@@ -410,6 +531,8 @@ class DataProcessor:
         )[1]
         num_frames = len(frames)
         num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)
+        if token_len and num_tokens != token_len:
+            raise ValueError("video tokens num not match the size")

         pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
         ret = self.image_preprocessor.preprocess(
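Videos additionally merge temporal_conv_size consecutive frames into each token. With the numbers from test_prompt_token_ids2outputs_add_video (2 frames, a 4 x 4 patch grid, both conv sizes assumed to be 2):

num_frames, patches_h, patches_w = 2, 4, 4
spatial_conv_size, temporal_conv_size = 2, 2  # assumed defaults, matching the tests
num_tokens = (num_frames * patches_h * patches_w) // (spatial_conv_size**2 * temporal_conv_size)
# 32 // 8 = 4, matching the four placeholder ids between <|VIDEO_START|> and <|VIDEO_END|>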
@@ -438,9 +561,13 @@ class DataProcessor:
         outputs["position_ids"].extend(pos_ids)
         outputs["cur_position"] = np.max(pos_ids) + 1

-    def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
+    def _add_processed_video(
+        self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
+    ) -> None:
         frames, meta = frames_cache
         num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
+        if token_len and num_tokens != token_len:
+            raise ValueError("video tokens num not match the size")

         t, h, w = meta["thw"]
         outputs["images"].append(frames)
@@ -1,7 +1,15 @@
 import unittest
 from unittest.mock import MagicMock, patch

+import numpy as np
+
+from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
 from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor
+from fastdeploy.input.ernie4_5_vl_processor.image_preprocessor.image_preprocessor_adaptive import (
+    AdaptiveImageProcessor,
+)
+from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor
+from fastdeploy.input.utils import IDS_TYPE_FLAG


 class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
@@ -77,7 +85,7 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
             "prompt_token_ids": [1, 1, 1],
         }
         self.processor.process_request_dict(request_dict, 100)
-        self.assertEqual(request_dict["enable_thinking"], False)
+        self.assertEqual(request_dict["enable_thinking"], True)

         request_dict = {
             "messages": [{"role": "user", "content": "Hello"}],
@@ -93,7 +101,7 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
             "prompt_token_ids": [1, 1, 1],
         }
         self.processor.process_request_dict(request_dict, 100)
-        self.assertEqual(request_dict["enable_thinking"], False)
+        self.assertEqual(request_dict["enable_thinking"], True)

         request_dict = {
             "messages": [{"role": "user", "content": "Hello"}],
@@ -101,7 +109,7 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
             "prompt_token_ids": [1, 1, 1],
         }
         self.processor.process_request_dict(request_dict, 100)
-        self.assertEqual(request_dict["enable_thinking"], False)
+        self.assertEqual(request_dict["enable_thinking"], True)

         request_dict = {
             "messages": [{"role": "user", "content": "Hello"}],
@@ -111,6 +119,446 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
         self.processor.process_request_dict(request_dict, 100)
         self.assertEqual(request_dict["enable_thinking"], True)

+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "chat_template_kwargs": {"options": {"thinking_mode": "close"}},
+        }
+        self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(request_dict["enable_thinking"], False)
+
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "chat_template_kwargs": {"options": {"thinking_mode": "false"}},
+        }
+        self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(request_dict["enable_thinking"], False)
+
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "chat_template_kwargs": {"enable_thinking": False},
+        }
+        self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(request_dict["enable_thinking"], False)
+
+
+class TestDataProcessorTargetMethods(unittest.TestCase):
+    def setUp(self):
+        self.mock_tokenizer = MagicMock(spec=Ernie4_5Tokenizer)
+        self.mock_tokenizer.ignored_index = -100
+        self.mock_tokenizer.convert_tokens_to_ids.side_effect = self._mock_convert_tokens_to_ids
+        self.mock_tokenizer.chat_template = "mock_template"
+        self.mock_tokenizer.apply_chat_template.return_value = "User: Hello<|image@placeholder|>"
+
+        def mock_load_tokenizer(dp_instance):
+            dp_instance.tokenizer = self.mock_tokenizer
+
+        with patch.object(DataProcessor, "_load_tokenizer", side_effect=mock_load_tokenizer, autospec=True):
+            with patch.object(AdaptiveImageProcessor, "from_pretrained") as mock_image_preprocessor:
+                mock_image_preprocessor.return_value = MagicMock()
+                self.data_processor = DataProcessor(
+                    tokenizer_name="mock_tokenizer",
+                    image_preprocessor_name="mock_image_preprocessor",
+                    enable_processor_cache=False,
+                )
+        self.data_processor.image_patch_id = 1001
+        self.data_processor.image_start_id = 1002
+        self.data_processor.image_end_id = 1003
+        self.data_processor.video_start_id = 1004
+        self.data_processor.video_end_id = 1005
+        self.data_processor.role_prefixes = {"user": "User: ", "assistant": "Assistant: "}
+        self.data_processor.enable_processor_cache = False
+        self.data_processor.extract_mm_items = MagicMock(return_value=([], [], [], [], None, [], []))
+
+    def _mock_convert_tokens_to_ids(self, token):
+        token_id_map = {
+            "<|begin_of_sentence|>": 101,
+            "<|end_of_sentence|>": 102,
+            "</s>": 103,
+            "<|IMAGE_PLACEHOLDER|>": 1001,
+            "<|IMAGE_START|>": 1002,
+            "<|IMAGE_END|>": 1003,
+            "<|VIDEO_START|>": 1004,
+            "<|VIDEO_END|>": 1005,
+        }
+        return token_id_map.get(token, 999)
+
+    def test_prompt_token_ids2outputs_only_prompt_token_ids(self):
+        test_prompt_token_ids = [101, 999, 998, 997, 102]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+        }
+
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+
+        prompt_len = len(test_prompt_token_ids)
+
+        self.assertEqual(
+            outputs["input_ids"],
+            test_prompt_token_ids,
+            f"input_ids mismatch: got {outputs['input_ids']}, expected {test_prompt_token_ids}",
+        )
+
+        self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)
+
+        expected_position_ids = [[i] * 3 for i in range(prompt_len)]
+        self.assertEqual(outputs["position_ids"], expected_position_ids)
+
+        self.assertEqual(outputs["cur_position"], prompt_len)
+
+        self.assertEqual(len(outputs["images"]), 0)
+        self.assertEqual(len(outputs["grid_thw"]), 0)
+        self.assertEqual(len(outputs["mm_positions"]), 0)
+        self.assertEqual(len(outputs["mm_hashes"]), 0)
+        self.assertEqual(outputs["video_cnt"], 0)
+        self.assertEqual(outputs["num_input_image_tokens"], 0)
+        self.assertEqual(outputs["num_input_video_tokens"], 0)
+
+    def test_prompt_token_ids2outputs_with_messages_no_mm(self):
+        test_prompt_token_ids = [101, 999, 998, 997, 102]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [{"role": "user", "content": "Hello World"}],
+        }
+
+        self.data_processor.extract_mm_items.return_value = ([], [], [], [], None, [], [])
+
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+
+        prompt_len = len(test_prompt_token_ids)
+
+        self.assertEqual(outputs["input_ids"], test_prompt_token_ids)
+
+        self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)
+
+        expected_position_ids = [[i] * 3 for i in range(prompt_len)]
+        self.assertEqual(outputs["position_ids"], expected_position_ids)
+
+        self.assertEqual(outputs["cur_position"], prompt_len)
+
+        self.assertEqual(len(outputs["images"]), 0)
+        self.assertEqual(outputs["video_cnt"], 0)
+        self.assertEqual(outputs["num_input_image_tokens"], 0)
+
+    def test_prompt_token_ids2outputs_add_image(self):
+        test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
+        mock_img = MagicMock()
+        mock_img.height = 224
+        mock_img.width = 224
+        mock_img.convert.return_value = mock_img
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "image_url", "image_url": mock_img, "uuid": "img_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [mock_img],
+            [],
+            ["img_uuid"],
+            [],
+            None,
+            [],
+            [{"type": "image", "data": mock_img}],
+        )
+        mock_resize = (None, (2, 4))
+        self.data_processor.image_preprocessor.get_smarted_resize.return_value = mock_resize
+        mock_preprocess = {"pixel_values": np.random.randn(1, 16, 16, 3), "image_grid_thw": np.array([[2, 4]])}
+        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
+        # self.data_processor._compute_3d_positions = MagicMock(return_value=[[i]*3 for i in range(4)])
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+        self.assertEqual(outputs["input_ids"], [101, 1002, 1001, 1001, 1003, 102])
+        self.assertEqual(
+            outputs["token_type_ids"],
+            [
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["image"],
+                IDS_TYPE_FLAG["image"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+            ],
+        )
+        self.assertEqual(len(outputs["position_ids"]), 6)
+        self.assertEqual(outputs["cur_position"], 6)
+        self.assertEqual(len(outputs["images"]), 1)
+        self.assertIsNotNone(outputs["images"][0])
+        self.assertEqual(outputs["num_input_image_tokens"], 2)
+        self.assertEqual(len(outputs["mm_positions"]), 1)
+        self.assertEqual(len(outputs["mm_hashes"]), 1)
+        self.assertEqual(len(outputs["grid_thw"]), 1)
+        self.assertEqual(len(outputs["image_type_ids"]), 1)
+
+    def test_prompt_token_ids2outputs_add_processed_image(self):
+        test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
+        mock_img_data = np.random.randn(8, 28, 28)
+        mock_img_cache = (mock_img_data, {"thw": (1, 8, 8)})
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "image_url", "image_url": mock_img_cache, "uuid": "img_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [mock_img_cache],
+            [],
+            ["img_uuid"],
+            [],
+            None,
+            [],
+            [{"type": "image", "data": mock_img_cache}],
+        )
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+        self.assertEqual(outputs["input_ids"], [101, 1002, 1001, 1001, 1003, 102])
+        self.assertEqual(
+            outputs["token_type_ids"],
+            [
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["image"],
+                IDS_TYPE_FLAG["image"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+            ],
+        )
+        self.assertEqual(len(outputs["position_ids"]), 20)
+        self.assertEqual(outputs["cur_position"], 8)
+        self.assertEqual(len(outputs["images"]), 1)
+        self.assertIsNotNone(outputs["images"][0])
+        self.assertEqual(len(outputs["mm_positions"]), 1)
+        self.assertEqual(outputs["mm_hashes"][0], "img_uuid")
+        self.assertEqual(len(outputs["grid_thw"]), 1)
+        self.assertEqual(len(outputs["image_type_ids"]), 1)
+
+    def test_prompt_token_ids2outputs_add_video(self):
+        test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]
+        mock_frame1 = MagicMock()
+        mock_frame1.height = 224
+        mock_frame1.width = 224
+        mock_frame1.convert.return_value = mock_frame1
+        mock_frame2 = MagicMock()
+        mock_frame2.height = 224
+        mock_frame2.width = 224
+        mock_frame2.convert.return_value = mock_frame2
+        frames = [mock_frame1, mock_frame2]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [],
+            [frames],
+            [],
+            ["vid_uuid"],
+            None,
+            [],
+            [{"type": "video", "data": frames}],
+        )
+        self.data_processor._load_and_process_video = MagicMock(return_value=frames)
+        patches_h, patches_w = 4, 4
+        self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
+        mock_preprocess = {
+            "pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3),
+            "video_grid_thw": np.array([[patches_h, patches_w]] * 2),
+        }
+        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+        self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102])
+        self.assertEqual(
+            outputs["token_type_ids"],
+            [
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+            ],
+        )
+        self.assertEqual(len(outputs["position_ids"]), 8)
+        self.assertEqual(outputs["cur_position"], 6)
+        self.assertEqual(len(outputs["images"]), 1)
+        self.assertIsNotNone(outputs["images"][0])
+        self.assertEqual(len(outputs["mm_positions"]), 1)
+        self.assertEqual(outputs["mm_hashes"][0], "vid_uuid")
+        self.assertEqual(len(outputs["grid_thw"]), 1)
+        self.assertEqual(len(outputs["image_type_ids"]), 2)
+        self.assertEqual(outputs["num_input_video_tokens"], 4)
+
+    def test_prompt_token_ids2outputs_add_processed_video(self):
+        test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]
+        t, h, w = 2, 4, 4
+        spatial_conv_size = self.data_processor.spatial_conv_size
+        temporal_conv_size = self.data_processor.temporal_conv_size
+        token_per_frame = (h // spatial_conv_size) * (w // spatial_conv_size)
+        num_tokens = (t // temporal_conv_size) * token_per_frame
+        mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28)
+        mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)})
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [],
+            [mock_frames_cache],
+            [],
+            ["vid_uuid"],
+            None,
+            [],
+            [{"type": "video", "data": mock_frames_cache}],
+        )
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+        self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102])
+        self.assertEqual(
+            outputs["token_type_ids"],
+            [
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["video"],
+                IDS_TYPE_FLAG["text"],
+                IDS_TYPE_FLAG["text"],
+            ],
+        )
+        self.assertEqual(len(outputs["position_ids"]), 8)
+        self.assertEqual(outputs["cur_position"], 6)
+        self.assertEqual(len(outputs["images"]), 1)
+        self.assertIsNotNone(outputs["images"][0])
+        self.assertEqual(len(outputs["mm_positions"]), 1)
+        self.assertEqual(outputs["mm_hashes"][0], "vid_uuid")
+        self.assertEqual(len(outputs["grid_thw"]), 1)
+        self.assertEqual(len(outputs["image_type_ids"]), 2)
+
+    def test_prompt_token_ids2outputs_add_image_token_len_mismatch(self):
+        test_prompt_token_ids = [101, 1002, 1001, 1001, 1001, 1003, 102]
+        mock_img = MagicMock()
+        mock_img.height = 224
+        mock_img.width = 224
+        mock_img.convert.return_value = mock_img
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "image_url", "image_url": mock_img, "uuid": "img_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [mock_img],
+            [],
+            ["img_uuid"],
+            [],
+            None,
+            [],
+            [{"type": "image", "data": mock_img}],
+        )
+        patches_h, patches_w = 8, 8
+        self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
+        mock_preprocess = {
+            "pixel_values": np.random.randn(1, patches_h, patches_w, 3),
+            "image_grid_thw": np.array([[patches_h, patches_w]]),
+        }
+        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
+        with self.assertRaises(ValueError) as ctx:
+            self.data_processor.prompt_token_ids2outputs(request)
+        self.assertIn("image tokens num not match the size", str(ctx.exception))
+
+    def test_prompt_token_ids2outputs_add_processed_image_token_len_mismatch(self):
+        test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102]
+        spatial_conv_size = self.data_processor.spatial_conv_size
+        num_tokens = 4
+        mock_img_data = np.random.randn(num_tokens * (spatial_conv_size**2), 28, 28)
+        mock_img_cache = (mock_img_data, {"thw": (1, 8, 8)})
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "image_url", "image_url": mock_img_cache, "uuid": "img_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [mock_img_cache],
+            [],
+            ["img_uuid"],
+            [],
+            None,
+            [],
+            [{"type": "image", "data": mock_img_cache}],
+        )
+        with self.assertRaises(ValueError) as ctx:
+            self.data_processor.prompt_token_ids2outputs(request)
+        self.assertIn("image tokens num not match the size", str(ctx.exception))
+
+    def test_prompt_token_ids2outputs_add_video_token_len_mismatch(self):
+        test_prompt_token_ids = [101, 1004, 1001, 1001, 1005, 102]
+        mock_frame1 = MagicMock()
+        mock_frame1.height = 224
+        mock_frame1.width = 224
+        mock_frame1.convert.return_value = mock_frame1
+        mock_frame2 = MagicMock()
+        mock_frame2.height = 224
+        mock_frame2.width = 224
+        mock_frame2.convert.return_value = mock_frame2
+        frames = [mock_frame1, mock_frame2]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [],
+            [frames],
+            [],
+            ["vid_uuid"],
+            None,
+            [],
+            [{"type": "video", "data": frames}],
+        )
+        self.data_processor._load_and_process_video = MagicMock(return_value=frames)
+        patches_h, patches_w = 8, 8
+        self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
+        mock_preprocess = {
+            "pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3),
+            "video_grid_thw": np.array([[patches_h, patches_w]] * 2),
+        }
+        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
+        with self.assertRaises(ValueError) as ctx:
+            self.data_processor.prompt_token_ids2outputs(request)
+        self.assertIn("video tokens num not match the size", str(ctx.exception))
+
+    def test_prompt_token_ids2outputs_add_processed_video_token_len_mismatch(self):
+        test_prompt_token_ids = [101, 1004, 1001, 1005, 102]
+        t, h, w = 2, 8, 8
+        spatial_conv_size = self.data_processor.spatial_conv_size
+        temporal_conv_size = self.data_processor.temporal_conv_size
+
+        num_tokens = 4
+        mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28)
+        mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)})
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [],
+            [mock_frames_cache],
+            [],
+            ["vid_uuid"],
+            None,
+            [],
+            [{"type": "video", "data": mock_frames_cache}],
+        )
+        with self.assertRaises(ValueError) as ctx:
+            self.data_processor.prompt_token_ids2outputs(request)
+        self.assertIn("video tokens num not match the size", str(ctx.exception))
+
+
 if __name__ == "__main__":
     unittest.main()