mirror of
https://github.com/s0md3v/roop.git
synced 2025-09-26 20:31:16 +08:00
Make roop more or less type-safe (#541)
* Make roop more or less type-safe * Fix ci.yml * Fix urllib type error * Rename globals in ui
This commit is contained in:
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -13,7 +13,9 @@ jobs:
|
||||
with:
|
||||
python-version: 3.9
|
||||
- run: pip install flake8
|
||||
- run: pip install mypy
|
||||
- run: flake8 run.py roop
|
||||
- run: mypy --config-file mypi.ini run.py roop
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
7
mypi.ini
Normal file
7
mypi.ini
Normal file
@@ -0,0 +1,7 @@
|
||||
[mypy]
|
||||
check_untyped_defs = True
|
||||
disallow_any_generics = True
|
||||
disallow_untyped_calls = True
|
||||
disallow_untyped_defs = True
|
||||
ignore_missing_imports = True
|
||||
strict_optional = False
|
@@ -86,7 +86,7 @@ def parse_args() -> None:
|
||||
roop.globals.execution_providers = decode_execution_providers(['cuda'])
|
||||
if args.gpu_vendor_deprecated == 'amd':
|
||||
print('\033[33mArgument --gpu-vendor amd is deprecated. Use --execution-provider cuda instead.\033[0m')
|
||||
roop.globals.execution_threads = decode_execution_providers(['rocm'])
|
||||
roop.globals.execution_providers = decode_execution_providers(['rocm'])
|
||||
if args.gpu_threads_deprecated:
|
||||
print('\033[33mArgument --gpu-threads is deprecated. Use --execution-threads instead.\033[0m')
|
||||
roop.globals.execution_threads = args.gpu_threads_deprecated
|
||||
|
@@ -1,7 +1,9 @@
|
||||
from typing import List
|
||||
|
||||
source_path = None
|
||||
target_path = None
|
||||
output_path = None
|
||||
frame_processors = []
|
||||
frame_processors: List[str] = []
|
||||
keep_fps = None
|
||||
keep_audio = None
|
||||
keep_frames = None
|
||||
@@ -9,7 +11,7 @@ many_faces = None
|
||||
video_encoder = None
|
||||
video_quality = None
|
||||
max_memory = None
|
||||
execution_providers = []
|
||||
execution_providers: List[str] = []
|
||||
execution_threads = None
|
||||
headless = None
|
||||
log_level = 'error'
|
||||
|
@@ -1,12 +1,13 @@
|
||||
import sys
|
||||
import importlib
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, List
|
||||
from types import ModuleType
|
||||
from typing import Any, List, Callable
|
||||
from tqdm import tqdm
|
||||
|
||||
import roop
|
||||
|
||||
FRAME_PROCESSORS_MODULES = None
|
||||
FRAME_PROCESSORS_MODULES: List[ModuleType] = []
|
||||
FRAME_PROCESSORS_INTERFACE = [
|
||||
'pre_check',
|
||||
'pre_start',
|
||||
@@ -27,17 +28,17 @@ def load_frame_processor_module(frame_processor: str) -> Any:
|
||||
return frame_processor_module
|
||||
|
||||
|
||||
def get_frame_processors_modules(frame_processors):
|
||||
def get_frame_processors_modules(frame_processors: List[str]) -> List[ModuleType]:
|
||||
global FRAME_PROCESSORS_MODULES
|
||||
if FRAME_PROCESSORS_MODULES is None:
|
||||
FRAME_PROCESSORS_MODULES = []
|
||||
|
||||
if not FRAME_PROCESSORS_MODULES:
|
||||
for frame_processor in frame_processors:
|
||||
frame_processor_module = load_frame_processor_module(frame_processor)
|
||||
FRAME_PROCESSORS_MODULES.append(frame_processor_module)
|
||||
return FRAME_PROCESSORS_MODULES
|
||||
|
||||
|
||||
def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames, progress) -> None:
|
||||
def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames: Callable[[str, List[str], Any], None], progress: Any = None) -> None:
|
||||
with ThreadPoolExecutor(max_workers=roop.globals.execution_threads) as executor:
|
||||
futures = []
|
||||
for path in temp_frame_paths:
|
||||
@@ -47,7 +48,7 @@ def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_f
|
||||
future.result()
|
||||
|
||||
|
||||
def process_video(source_path: str, frame_paths: list[str], process_frames: Any) -> None:
|
||||
def process_video(source_path: str, frame_paths: list[str], process_frames: Callable[[str, List[str], Any], None]) -> None:
|
||||
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
|
||||
total = len(frame_paths)
|
||||
with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
|
||||
|
@@ -28,14 +28,14 @@ def pre_start() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_face_enhancer() -> None:
|
||||
def get_face_enhancer() -> Any:
|
||||
global FACE_ENHANCER
|
||||
|
||||
with THREAD_LOCK:
|
||||
if FACE_ENHANCER is None:
|
||||
model_path = resolve_relative_path('../models/GFPGANv1.3.pth')
|
||||
# todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
|
||||
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1)
|
||||
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
|
||||
return FACE_ENHANCER
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ def process_frame(source_face: Any, temp_frame: Any) -> Any:
|
||||
return temp_frame
|
||||
|
||||
|
||||
def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None:
|
||||
def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None:
|
||||
for temp_frame_path in temp_frame_paths:
|
||||
temp_frame = cv2.imread(temp_frame_path)
|
||||
result = process_frame(None, temp_frame)
|
||||
|
@@ -33,7 +33,7 @@ def pre_start() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_face_swapper() -> None:
|
||||
def get_face_swapper() -> Any:
|
||||
global FACE_SWAPPER
|
||||
|
||||
with THREAD_LOCK:
|
||||
@@ -60,7 +60,7 @@ def process_frame(source_face: Any, temp_frame: Any) -> Any:
|
||||
return temp_frame
|
||||
|
||||
|
||||
def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None:
|
||||
def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None:
|
||||
source_face = get_one_face(cv2.imread(source_path))
|
||||
for temp_frame_path in temp_frame_paths:
|
||||
temp_frame = cv2.imread(temp_frame_path)
|
||||
|
28
roop/ui.py
28
roop/ui.py
@@ -12,16 +12,26 @@ from roop.predicter import predict_frame
|
||||
from roop.processors.frame.core import get_frame_processors_modules
|
||||
from roop.utilities import is_image, is_video, resolve_relative_path
|
||||
|
||||
WINDOW_HEIGHT = 700
|
||||
WINDOW_WIDTH = 600
|
||||
ROOT = None
|
||||
ROOT_HEIGHT = 700
|
||||
ROOT_WIDTH = 600
|
||||
|
||||
PREVIEW = None
|
||||
PREVIEW_MAX_HEIGHT = 700
|
||||
PREVIEW_MAX_WIDTH = 1200
|
||||
|
||||
RECENT_DIRECTORY_SOURCE = None
|
||||
RECENT_DIRECTORY_TARGET = None
|
||||
RECENT_DIRECTORY_OUTPUT = None
|
||||
|
||||
preview_label = None
|
||||
preview_slider = None
|
||||
source_label = None
|
||||
target_label = None
|
||||
status_label = None
|
||||
|
||||
def init(start: Callable, destroy: Callable) -> ctk.CTk:
|
||||
|
||||
def init(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
|
||||
global ROOT, PREVIEW
|
||||
|
||||
ROOT = create_root(start, destroy)
|
||||
@@ -30,14 +40,14 @@ def init(start: Callable, destroy: Callable) -> ctk.CTk:
|
||||
return ROOT
|
||||
|
||||
|
||||
def create_root(start: Callable, destroy: Callable) -> ctk.CTk:
|
||||
def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
|
||||
global source_label, target_label, status_label
|
||||
|
||||
ctk.deactivate_automatic_dpi_awareness()
|
||||
ctk.set_appearance_mode('system')
|
||||
ctk.set_default_color_theme(resolve_relative_path('ui.json'))
|
||||
root = ctk.CTk()
|
||||
root.minsize(WINDOW_WIDTH, WINDOW_HEIGHT)
|
||||
root.minsize(ROOT_WIDTH, ROOT_HEIGHT)
|
||||
root.title('roop')
|
||||
root.configure()
|
||||
root.protocol('WM_DELETE_WINDOW', lambda: destroy())
|
||||
@@ -85,7 +95,7 @@ def create_root(start: Callable, destroy: Callable) -> ctk.CTk:
|
||||
return root
|
||||
|
||||
|
||||
def create_preview(parent) -> ctk.CTkToplevel:
|
||||
def create_preview(parent: ctk.CTkToplevel) -> ctk.CTkToplevel:
|
||||
global preview_label, preview_slider
|
||||
|
||||
preview = ctk.CTkToplevel(parent)
|
||||
@@ -143,7 +153,7 @@ def select_target_path() -> None:
|
||||
target_label.configure(image=None)
|
||||
|
||||
|
||||
def select_output_path(start):
|
||||
def select_output_path(start: Callable[[], None]) -> None:
|
||||
global RECENT_DIRECTORY_OUTPUT
|
||||
|
||||
if is_image(roop.globals.target_path):
|
||||
@@ -158,14 +168,14 @@ def select_output_path(start):
|
||||
start()
|
||||
|
||||
|
||||
def render_image_preview(image_path: str, size: Tuple[int, int] = None) -> ctk.CTkImage:
|
||||
def render_image_preview(image_path: str, size: Tuple[int, int]) -> ctk.CTkImage:
|
||||
image = Image.open(image_path)
|
||||
if size:
|
||||
image = ImageOps.fit(image, size, Image.LANCZOS)
|
||||
return ctk.CTkImage(image, size=image.size)
|
||||
|
||||
|
||||
def render_video_preview(video_path: str, size: Tuple[int, int] = None, frame_number: int = 0) -> ctk.CTkImage:
|
||||
def render_video_preview(video_path: str, size: Tuple[int, int], frame_number: int = 0) -> ctk.CTkImage:
|
||||
capture = cv2.VideoCapture(video_path)
|
||||
if frame_number:
|
||||
capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
||||
|
@@ -7,7 +7,7 @@ import ssl
|
||||
import subprocess
|
||||
import urllib
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Any
|
||||
from tqdm import tqdm
|
||||
|
||||
import roop.globals
|
||||
@@ -76,7 +76,7 @@ def get_temp_output_path(target_path: str) -> str:
|
||||
return os.path.join(temp_directory_path, TEMP_FILE)
|
||||
|
||||
|
||||
def normalize_output_path(source_path: str, target_path: str, output_path: str) -> str:
|
||||
def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Any:
|
||||
if source_path and target_path:
|
||||
source_name, _ = os.path.splitext(os.path.basename(source_path))
|
||||
target_name, target_extension = os.path.splitext(os.path.basename(target_path))
|
||||
@@ -114,14 +114,14 @@ def has_image_extension(image_path: str) -> bool:
|
||||
def is_image(image_path: str) -> bool:
|
||||
if image_path and os.path.isfile(image_path):
|
||||
mimetype, _ = mimetypes.guess_type(image_path)
|
||||
return mimetype and mimetype.startswith('image/')
|
||||
return bool(mimetype and mimetype.startswith('image/'))
|
||||
return False
|
||||
|
||||
|
||||
def is_video(video_path: str) -> bool:
|
||||
if video_path and os.path.isfile(video_path):
|
||||
mimetype, _ = mimetypes.guess_type(video_path)
|
||||
return mimetype and mimetype.startswith('video/')
|
||||
return bool(mimetype and mimetype.startswith('video/'))
|
||||
return False
|
||||
|
||||
|
||||
@@ -131,10 +131,10 @@ def conditional_download(download_directory_path: str, urls: List[str]) -> None:
|
||||
for url in urls:
|
||||
download_file_path = os.path.join(download_directory_path, os.path.basename(url))
|
||||
if not os.path.exists(download_file_path):
|
||||
request = urllib.request.urlopen(url)
|
||||
request = urllib.request.urlopen(url) # type: ignore[attr-defined]
|
||||
total = int(request.headers.get('Content-Length', 0))
|
||||
with tqdm(total=total, desc='Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress:
|
||||
urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size))
|
||||
urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def resolve_relative_path(path: str) -> str:
|
||||
|
Reference in New Issue
Block a user