Files
FastDeploy/fastdeploy/input/mm_processor/utils/io_utils.py
2025-07-19 23:19:27 +08:00

265 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import base64
import datetime
import hashlib
import io
import os
import threading
import uuid
from pathlib import Path
import numpy as np
import requests
from PIL import Image
from PIL.ExifTags import TAGS
# Directory constants for downloaded / intermediate media. All paths are
# relative, so they resolve against the current working directory.
# RAW_VIDEO_DIR is the default ``download_dir`` of ``get_downloadable``.
RAW_VIDEO_DIR = "./download_tmp/raw_video/"
# NOTE(review): the three constants below are not referenced in the visible
# part of this file; presumably they are imported by other modules.
RAW_IMAGE_DIR = "./download_tmp/raw_images/"
EXTRACTED_FRAME_DIR = "./download_tmp/extracted_frames/"
TMP_DIR = "./download_tmp/upload_tmp/"
def file_download(url, download_dir, save_to_disk=False, retry=0, retry_interval=3):
    """
    Resolve ``url`` into raw bytes or into a file path on disk.

    Args:
        url: One of the following:
            * a ``PIL.Image.Image`` or ``VideoReaderWrapper`` -- returned unchanged;
            * a string starting with ``"http"`` -- fetched with ``requests.get``;
            * a path to an existing local file;
            * any other string -- decoded as base64 data.
        download_dir: Directory the payload is written to when ``save_to_disk``
            is true and the payload is not already a local file.
        save_to_disk: If true, return a file path; otherwise return the bytes.
        retry: Number of additional HTTP attempts after a failed request.
        retry_interval: Seconds to sleep between two HTTP attempts.

    Returns:
        The input object itself (image / video reader), the payload as
        ``bytes``, or the path of the file holding the payload.

    Raises:
        requests.RequestException: If the HTTP download still fails (network
            error or 4xx/5xx status) after all attempts.
    """
    import time

    from .video_utils import VideoReaderWrapper

    # Already-decoded objects need no download at all.
    if isinstance(url, (Image.Image, VideoReaderWrapper)):
        return url

    if url.startswith("http"):
        attempts = max(int(retry or 0), 0) + 1
        for attempt in range(attempts):
            try:
                response = requests.get(url)
                # Treat 4xx/5xx as a failure instead of returning the error
                # page as if it were the requested media.
                response.raise_for_status()
            except requests.RequestException:
                if attempt + 1 >= attempts:
                    raise
                time.sleep(retry_interval)
            else:
                bytes_data = response.content
                break
    elif os.path.isfile(url):
        if save_to_disk:
            # The file is already on disk; reuse it instead of copying.
            return url
        with open(url, "rb") as f:
            bytes_data = f.read()
    else:
        bytes_data = base64.b64decode(url)

    if not save_to_disk:
        return bytes_data

    download_path = os.path.join(download_dir, get_filename(url))
    Path(download_path).parent.mkdir(parents=True, exist_ok=True)
    with open(download_path, "wb") as f:
        f.write(bytes_data)
    return download_path
def get_filename(url=None):
    """
    Build an extension-less file name for a downloaded resource.

    Args:
        url: ``str`` or ``bytes`` identifying the resource, or ``None``.

    Returns:
        A random 32-character hex string when ``url`` is ``None``; otherwise
        ``"<YYYY>-<MM>-<DD>-<pid>-<thread id>-<md5 of url>"``.
    """
    if url is None:
        return uuid.uuid4().hex

    today = datetime.datetime.now()
    payload = url if isinstance(url, bytes) else url.encode("utf-8")
    digest = hashlib.md5(payload).hexdigest()

    # No file extension is appended, to avoid errors when the file is later
    # saved as JPEG.
    date_part = "{}-{:02d}-{:02d}".format(today.year, today.month, today.day)
    return "-".join([date_part, str(os.getpid()), str(threading.get_ident()), digest])
def get_downloadable(
    url,
    download_dir=RAW_VIDEO_DIR,
    save_to_disk=False,
    retry=0,
    retry_interval=3,
):
    """
    Make sure ``download_dir`` exists, then delegate to ``file_download``.

    Args:
        url: Resource to fetch; forwarded to ``file_download``.
        download_dir: Target directory; created if it does not exist.
        save_to_disk: Forwarded to ``file_download``.
        retry: Forwarded to ``file_download``.
        retry_interval: Forwarded to ``file_download``.

    Returns:
        Whatever ``file_download`` returns: the downloaded **path** if
        ``save_to_disk`` is true, the downloaded **bytes** otherwise.
    """
    # exist_ok avoids the check-then-create race when several threads or
    # processes prepare the same directory at once.
    os.makedirs(download_dir, exist_ok=True)
    downloaded_path = file_download(
        url,
        download_dir,
        save_to_disk=save_to_disk,
        retry=retry,
        retry_interval=retry_interval,
    )
    return downloaded_path
def get_downloadable_image(download_path, need_exif_info, retry_max_time=0, retry_interval=3):
    """
    Fetch an image via ``get_downloadable`` and return it as RGB with its EXIF data.

    Args:
        download_path: Image source, forwarded to ``get_downloadable``.
        need_exif_info: If true, try to read the image's EXIF data.
        retry_max_time: Forwarded to ``get_downloadable`` as ``retry``.
        retry_interval: Forwarded to ``get_downloadable``.

    Returns:
        ``(pil_image, exif_info)``: the image converted to ``"RGB"`` and a dict
        mapping EXIF tag names to values. The dict is empty when EXIF data was
        not requested or could not be read.
    """

    def get_image_exif(image):
        """Return the image's EXIF data as a ``{tag name: value}`` dict."""
        exif_data = image._getexif()
        exif_info = {}
        if exif_data is not None:
            for tag, value in exif_data.items():
                tag_name = TAGS.get(tag, tag)
                # Only text-like values can be stripped. Calling .strip() on
                # any other value raises, which would make the caller discard
                # every tag.
                if isinstance(value, (str, bytes)):
                    value = value.strip()
                exif_info[tag_name] = value
        return exif_info

    def has_transparent_background(img):
        """Return True if the image contains at least one non-opaque pixel."""
        if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
            # Check for any pixel with alpha channel less than 255 (fully opaque)
            alpha = img.convert("RGBA").split()[-1]
            if alpha.getextrema()[0] < 255:
                return True
        return False

    def add_white_background(img):
        """Composite an image with transparency onto a white background."""
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        # Create a white canvas of the same size as the input image.
        img_white_background = Image.new("RGBA", img.size, (255, 255, 255))
        # Paste the image onto the canvas, using the image itself as the mask.
        img_white_background.paste(img, (0, 0), img)
        return img_white_background

    def change_I16_to_L(img):
        """Convert an image from mode ``I;16`` to mode ``L``."""
        # point() on I-mode images only supports add / subtract / multiply, so
        # the ``* (1 / 256)`` below must not be rewritten as a division.
        return img.point(lambda i: i * (1 / 256)).convert("L")

    image = get_downloadable(
        download_path,
        save_to_disk=False,
        retry=retry_max_time,
        retry_interval=retry_interval,
    )
    if isinstance(image, Image.Image):
        pil_image = image
    else:
        pil_image = Image.open(io.BytesIO(image))

    exif_info = {}
    if need_exif_info:
        try:
            exif_info = get_image_exif(pil_image)
        except Exception:
            # Best effort: an image without readable EXIF data is not an error.
            exif_info = {}

    try:
        if pil_image.mode == "I;16":
            pil_image = change_I16_to_L(pil_image)
        if has_transparent_background(pil_image):
            pil_image = add_white_background(pil_image)
    except Exception:
        # Best effort: on failure, continue with the image as it currently is.
        pass

    return pil_image.convert("RGB"), exif_info
def str2hash(url):
    """
    Return the SHA-256 hex digest of a string.

    Args:
        url: The string to hash; it is encoded with ``str.encode()`` first.
    """
    hasher = hashlib.sha256()
    hasher.update(url.encode())
    return hasher.hexdigest()
def pil2hash(pil):
    """
    Return the SHA-256 hex digest of a PIL image's PNG serialization.

    Args:
        pil: Object exposing ``save(fp, format=...)``, e.g. a ``PIL.Image.Image``.
    """
    with io.BytesIO() as png_buffer:
        # PNG is lossless, so compression does not affect the hash.
        pil.save(png_buffer, format="PNG")
        digest = hashlib.sha256(png_buffer.getvalue())
    return digest.hexdigest()
def imagepath_to_base64(image_path):
    """
    Load an image file, re-encode it as JPEG and return it base64-encoded.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The JPEG bytes of the RGB-converted image as a base64 ``str``.
    """
    # The context manager closes the underlying file deterministically;
    # convert() returns an independent in-memory copy.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def pil_image_to_base64(image):
    """
    Encode a PIL image as JPEG and return it base64-encoded.

    Args:
        image: The ``PIL.Image.Image`` to encode.

    Returns:
        The JPEG bytes of the image as a base64 ``str``.
    """
    # JPEG cannot store alpha or palette data, so saving e.g. an RGBA or P
    # image raises OSError. Convert such images to RGB first; images already
    # in a JPEG-writable mode are saved unchanged.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def http_to_pil_image(url):
    """
    Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image.

    Returns:
        The downloaded image converted to mode ``"RGB"``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    response = requests.get(url)
    # Fail on an error status instead of handing an error page to the decoder.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content)).convert("RGB")
def http_to_image_base64(url):
    """
    Download a resource over HTTP and return its raw bytes base64-encoded.

    Args:
        url: HTTP(S) address of the resource.

    Returns:
        The response body as a base64 ``str``. The body is not decoded or
        re-encoded as an image.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    response = requests.get(url)
    # Without this check an error page would be silently encoded and returned.
    response.raise_for_status()
    return base64.b64encode(response.content).decode("utf-8")
def base64_to_pil_image(base64_string):
    """
    Decode a base64 string and open the resulting bytes as a PIL image.

    Args:
        base64_string: Base64-encoded image file content.

    Returns:
        The image returned by ``Image.open`` on the decoded bytes.
    """
    decoded = base64.b64decode(base64_string)
    return Image.open(io.BytesIO(decoded))
def get_hashable(to_be_hashed):
    """
    Convert a supported value into ``bytes`` suitable for hashing.

    Args:
        to_be_hashed: ``bytes`` (returned as is), a ``PIL.Image.Image``
            (converted with ``tobytes()``) or a ``str`` (UTF-8 encoded).

    Raises:
        ValueError: If the value is of any other type.
    """
    if isinstance(to_be_hashed, bytes):
        return to_be_hashed
    if isinstance(to_be_hashed, Image.Image):
        return to_be_hashed.tobytes()
    if isinstance(to_be_hashed, str):
        return to_be_hashed.encode("utf-8")
    raise ValueError(f"not support type: {type(to_be_hashed)}")
def load_dict_from_npz(npzfile):
    """
    Read every array stored in an ``.npz`` archive into a plain dict.

    Args:
        npzfile: Path or file-like object accepted by ``numpy.load``.

    Returns:
        A dict mapping each name in the archive to its loaded array.
    """
    # NOTE(review): allow_pickle=True lets the archive carry pickled objects,
    # which can execute arbitrary code on load; only use with trusted files.
    loaded_dict = {}
    with np.load(npzfile, allow_pickle=True) as archive:
        for name in archive.files:
            loaded_dict[name] = archive[name]
    return loaded_dict