mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
* init * update code * fix code style & disable thinking * adapt for common_engine.update_mm_requests_chunk_size * use 3d rope * use flash_attn_unpadded * opt siglip * update to be compatible with the latest codebase * fix typo * optim OCR performance * fix bug * fix bug * fix bug * fix bug * normlize name * modify xpu rope * revert logger * fix bug * fix bug * fix bug * support default_v1 * optim performance * fix bug --------- Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com> Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
276 lines
9.9 KiB
Python
276 lines
9.9 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
"""Image processor class for Keye."""
|
|
|
|
# TODO: Support videos
|
|
|
|
import json
|
|
import logging
|
|
import math
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
from paddleformers.transformers.feature_extraction_utils import BatchFeature
|
|
from paddleformers.transformers.image_processing_utils import BaseImageProcessor
|
|
from paddleformers.transformers.image_utils import (
|
|
ImageInput,
|
|
is_valid_image,
|
|
make_list_of_images,
|
|
to_numpy_array,
|
|
)
|
|
|
|
# Default per-channel normalization statistics (the OpenAI CLIP values).
# ImageProcessor falls back to these when image_mean / image_std are not given.
_OPENAI_CLIP_MEAN = [
    0.48145466,
    0.4578275,
    0.40821073,
]
_OPENAI_CLIP_STD = [
    0.26862954,
    0.26130258,
    0.27577711,
]
|
|
|
|
|
|
def make_batched_images(images) -> List[ImageInput]:
    """
    Accepts images in list or nested list format, and makes a flat list of images for preprocessing.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image(s): a single image, a flat list/tuple of images, or a
            list/tuple of lists/tuples of images.

    Returns:
        list: A flat list of images. A flat list/tuple input is returned unchanged
        (the same object); a nested input is flattened; a single image is wrapped
        in a one-element list.

    Raises:
        ValueError: If ``images`` (or its first inner list) is empty, or if it
            cannot be interpreted as images.
    """
    if isinstance(images, (list, tuple)) and len(images) > 0:
        first = images[0]
        # Nested form: only the first element of the first inner list is checked.
        # Guard against an empty inner list, which previously raised IndexError.
        if isinstance(first, (list, tuple)) and len(first) > 0 and is_valid_image(first[0]):
            return [img for img_list in images for img in img_list]
        if is_valid_image(first):
            return images
    elif is_valid_image(images):
        return [images]

    # Reached for empty containers too, which previously escaped as IndexError.
    raise ValueError(f"Could not make batched images from {images}")
|
|
|
|
|
|
def adjust_size(size, patch_size):
    """Round ``size`` down so it spans an even number of whole patches.

    The number of whole ``patch_size`` patches that fit in ``size`` is
    truncated to an even count, and the matching length is returned. This is
    the largest multiple of ``2 * patch_size`` not exceeding ``size``, and it
    may be 0 when ``size < 2 * patch_size``.
    """
    patch_count = size // patch_size
    # Drop the trailing patch when the count is odd.
    even_count = patch_count - patch_count % 2
    return even_count * patch_size
|
|
|
|
|
|
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 28 * 28 * 130,
    max_pixels: int = 28 * 28 * 1280,
):
    """Compute a target ``(height, width)`` for resizing an image.

    The result satisfies, as closely as possible:

    1. Both dimensions (height and width) are divisible by ``factor``.
    2. The total number of pixels is within the range ``[min_pixels, max_pixels]``.
    3. The aspect ratio of the image is maintained as closely as possible.

    A side shorter than ``factor`` is first scaled up to ``factor``, and the
    other side is scaled by the same ratio. Each returned side is at least
    ``factor``. For extreme aspect ratios combined with a small ``max_pixels``,
    this floor takes precedence over the ``max_pixels`` bound rather than
    returning a 0-sized side.

    Args:
        height: Input height in pixels; must be positive.
        width: Input width in pixels; must be positive.
        factor: Both output dimensions are multiples of this value.
        min_pixels: Lower bound on ``h_bar * w_bar``.
        max_pixels: Upper bound on ``h_bar * w_bar`` (see the note above).

    Returns:
        tuple: ``(h_bar, w_bar)``, the target height and width.

    Raises:
        ValueError: If ``height`` or ``width`` is not positive (this previously
            surfaced as ``ZeroDivisionError``), or if the aspect ratio (longer
            side / shorter side) exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"height:{height} and width:{width} must be positive")

    if height < factor:
        logging.debug("smart_resize: height=%s < factor=%s, reset height=factor", height, factor)
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        logging.debug("smart_resize: width=%s < factor=%s, reset width=factor", width, factor)
        height = round((height * factor) / width)
        width = factor

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(f"absolute aspect ratio must be smaller than 200, got {aspect_ratio}")

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Flooring can yield 0 for the short side when max_pixels is small and the
        # aspect ratio is extreme; clamp so every side stays a valid multiple of factor.
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
|
|
|
|
|
|
class ImageProcessor(BaseImageProcessor):
    """Resize, rescale and normalize images, then cut them into square patches.

    Images are resized with ``smart_resize`` so that both sides are multiples of
    ``patch_size * merge_size``. They are then split into ``patch_size x patch_size``
    patches, returned as an array of shape
    ``(grid_t * grid_h * grid_w, channel, patch_size, patch_size)`` together with
    the grid ``[grid_t, grid_h, grid_w]``. Videos are not supported, and
    ``temporal_patch_size`` must be 1.
    """

    model_input_names = [
        "pixel_values",
        "image_grid_thw",
        "pixel_values_videos",
        "video_grid_thw",
    ]

    def __init__(
        self,
        do_resize: bool = True,
        resample: int = 3,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 28 * 28 * 130,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 1,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        """Store the preprocessing configuration.

        Args:
            do_resize: Resize images to the ``smart_resize`` target by default.
            resample: Resampling filter passed to ``image.resize``. For PIL
                images, 3 is ``PIL.Image.BICUBIC``.
            do_rescale: Multiply pixel values by ``rescale_factor`` by default.
            rescale_factor: Scale applied when rescaling (default ``1 / 255``).
            do_normalize: Subtract ``image_mean`` and divide by ``image_std`` by default.
            image_mean: Per-channel mean; defaults to the OpenAI CLIP mean.
            image_std: Per-channel std; defaults to the OpenAI CLIP std.
            do_convert_rgb: Call ``image.convert("RGB")`` before processing by default.
            min_pixels: Lower pixel-count bound passed to ``smart_resize``.
            max_pixels: Upper pixel-count bound passed to ``smart_resize``.
            patch_size: Side length of each square patch.
            temporal_patch_size: Frames per temporal patch; only 1 is supported.
            merge_size: Resize targets are multiples of ``patch_size * merge_size``.
            **kwargs: Extra keys (e.g. from ``preprocessor_config.json``),
                accepted and ignored.
        """
        super().__init__()
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Copy the module-level defaults so in-place edits of one instance's
        # mean/std cannot leak into the shared constants or other instances.
        self.image_mean = image_mean if image_mean is not None else list(_OPENAI_CLIP_MEAN)
        self.image_std = image_std if image_std is not None else list(_OPENAI_CLIP_STD)
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}  # not used
        self.do_convert_rgb = do_convert_rgb

    @classmethod
    def from_pretrained(cls, pretrained_model_dir):
        """Build a processor from ``<pretrained_model_dir>/preprocessor_config.json``.

        Every key of the JSON object is passed to the constructor as a keyword
        argument. Keys the constructor does not name are absorbed by ``**kwargs``.
        """
        pretrained_model_dir = Path(pretrained_model_dir)
        image_processor_config_path = pretrained_model_dir / "preprocessor_config.json"
        with open(image_processor_config_path, "r", encoding="utf-8") as f:
            image_processor_config = json.load(f)
        return cls(**image_processor_config)

    def _preprocess(
        self,
        images,
        do_resize: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
    ):
        """Turn a list of images into flattened patches plus their grid size.

        Every image is resized to the target computed from the FIRST image, so
        that all images can be stacked into a single array. The images are
        assumed to expose a PIL-like interface (``.convert``, ``.size`` as
        ``(width, height)``, ``.resize``) — presumably ``PIL.Image.Image``;
        TODO confirm for other input types. Flags are used as given; ``None``
        is treated as false.

        Returns:
            tuple: ``(flatten_patches, grid_thw)``. ``flatten_patches`` has shape
            ``(grid_t * grid_h * grid_w, channel, patch_size, patch_size)`` and
            ``grid_thw`` is ``np.array([grid_t, grid_h, grid_w])``.

        Raises:
            NotImplementedError: If ``self.temporal_patch_size`` is not 1. The
                final flattening below is only valid for a single frame per
                temporal patch.
        """
        if self.temporal_patch_size != 1:
            raise NotImplementedError(
                f"temporal_patch_size={self.temporal_patch_size} is not supported; only 1 is implemented"
            )

        images = make_list_of_images(images)

        if do_convert_rgb:
            images = [image.convert("RGB") for image in images]

        # The resize target depends only on the first image, so compute it once.
        width, height = images[0].size
        if do_resize:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=self.patch_size * self.merge_size,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )
        else:
            resized_height, resized_width = height, width

        if do_normalize:
            mean = np.array(image_mean, dtype=np.float32)
            std = np.array(image_std, dtype=np.float32)

        processed_images = []
        for image in images:
            if do_resize:
                image = image.resize((resized_width, resized_height), resample=self.resample)

            image = to_numpy_array(image)

            if do_rescale:
                image = (image * rescale_factor).astype(np.float32)

            if do_normalize:
                image = image.astype(np.float32)
                image -= mean
                image /= std

            processed_images.append(image)

        # Stack to (num_images, H, W, C), then move channels first: (N, C, H, W).
        patches = np.array(processed_images)
        patches = patches.transpose(0, 3, 1, 2)
        if patches.shape[0] == 1:
            patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h = resized_height // self.patch_size
        grid_w = resized_width // self.patch_size

        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h,
            self.patch_size,
            grid_w,
            self.patch_size,
        )
        # -> (grid_t, grid_h, grid_w, channel, temporal, patch_size, patch_size)
        patches = patches.transpose(0, 3, 5, 2, 1, 4, 6)
        flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size)
        return flatten_patches, np.array([grid_t, grid_h, grid_w])

    def preprocess(
        self,
        images,
        videos=None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors=None,
    ):
        """Preprocess images into patches, wrapped in a ``BatchFeature``.

        Any argument left as ``None`` falls back to the value configured on the
        instance. ``size`` is accepted for API compatibility but ignored:
        resizing is governed by ``min_pixels`` / ``max_pixels``.

        Returns:
            BatchFeature: ``{"pixel_values": patches, "grid_thw": grid}``,
            converted according to ``return_tensors``.

        Raises:
            NotImplementedError: If ``videos`` is given.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        if videos is not None:
            raise NotImplementedError("Videos are not yet supported")

        patches, image_grid_thw = self._preprocess(
            images,
            do_resize=do_resize,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_convert_rgb=do_convert_rgb,
        )
        pixel_values = np.array(patches)
        data = {"pixel_values": pixel_values, "grid_thw": image_grid_thw}
        return BatchFeature(data=data, tensor_type=return_tensors)
|