Skip to content
Draft

Main #176

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added dot_model_checkpoints.zip
Empty file.
13 changes: 10 additions & 3 deletions src/dot/commons/model_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union

import cv2
import torch
try:
import cv2
except ModuleNotFoundError:
cv2 = None
try:
import torch
except ModuleNotFoundError:
torch = None

from ..gpen.face_enhancement import FaceEnhancement
from .camera_utils import camera_pipeline, fetch_camera
from .utils import find_images_from_path, generate_random_file_idx, rand_idx_tuple
from .video.video_utils import video_pipeline
# from .video.video_utils import video_pipeline # Deferred import


class ModelOption(ABC):
Expand Down Expand Up @@ -211,6 +217,7 @@ def generate_from_video(
duration (int): Trim target video in seconds.
limit (int, optional): Limit number of video-swaps. Defaults to None.
"""
from .video.video_utils import video_pipeline # Deferred import
with torch.no_grad():
self.create_model(**kwargs)
video_pipeline(
Expand Down
2 changes: 1 addition & 1 deletion src/dot/commons/video/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _crop_and_pose(
)

try:
crop_image = image[ytop:ybot, xleft:xright]
crop_image = image[yinset-block-start:ybot, xinset-inline-start:xright]
if estimate_pose:
if pose_estimation(image=crop_image, roll=3, pitch=3, yaw=3) != 0:
return -1
Expand Down
6 changes: 5 additions & 1 deletion src/dot/gpen/retinaface/retinaface_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ class RetinaFaceDetection(object):
def __init__(self, base_dir, network="RetinaFace-R50", use_gpu=True):
torch.set_grad_enabled(False)
cudnn.benchmark = True
self.pretrained_path = os.path.join(base_dir, "weights", network + ".pth")
# If base_dir is empty or relative path, use saved_models/gpen/weights
if not base_dir or base_dir == "./":
self.pretrained_path = os.path.join("saved_models", "gpen", "weights", network + ".pth")
else:
self.pretrained_path = os.path.join(base_dir, "saved_models", "gpen", "weights", network + ".pth")
if use_gpu:
self.device = "mps" if torch.backends.mps.is_available() else "cuda"
else:
Expand Down
23 changes: 15 additions & 8 deletions src/dot/simswap/mediapipe/face_mesh.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
#!/usr/bin/env python3

from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, TYPE_CHECKING

import cv2
import mediapipe as mp
import numpy as np
from mediapipe.framework.formats.landmark_pb2 import NormalizedLandmark

if TYPE_CHECKING:
from mediapipe.framework.formats.landmark_pb2 import NormalizedLandmark

from .utils import face_align_ffhqandnewarc as face_align
from .utils import mediapipe_landmarks

mp_face_mesh = mp.solutions.face_mesh


class FaceMesh:
"""Wrapper class of Mediapipe's FaceMesh module. Extracts facial landmarks
Expand All @@ -38,18 +37,26 @@ def __init__(
min_detection_confidence: float = 0.5,
mode: str = "None",
):
# Deferred MediaPipe import
try:
import mediapipe as mp
from mediapipe.framework.formats.landmark_pb2 import NormalizedLandmark
self.mp_face_mesh = mp.solutions.face_mesh
except ImportError as e:
raise RuntimeError(f"MediaPipe import failed: {e}")

self.MediaPipeIds = mediapipe_landmarks.MediaPipeLandmarks
self.static_image_mode = static_image_mode
self.max_num_faces = max_num_faces
self.refine_landmarks = refine_landmarks
self.min_detection_confidence = min_detection_confidence
self.mode = mode

def _get_centroid(self, landmarks: List[NormalizedLandmark]) -> Tuple[float, float]:
def _get_centroid(self, landmarks: List["NormalizedLandmark"]) -> Tuple[float, float]:
"""Given a set of normalized landmarks/points finds centroid point

Args:
landmarks (List[NormalizedLandmark]): List of relative points that form a polygon
landmarks (List["NormalizedLandmark"]): List of relative points that form a polygon

Returns:
Tuple[float, float]: x,y coordinates of polygon centroid
Expand Down Expand Up @@ -88,7 +95,7 @@ def get_face_landmarks(self, image: np.ndarray) -> Optional[np.array]:
"""
# keypoints for all detected faces
detection_kpss = []
with mp_face_mesh.FaceMesh(
with self.mp_face_mesh.FaceMesh(
static_image_mode=self.static_image_mode,
max_num_faces=self.max_num_faces,
refine_landmarks=self.refine_landmarks,
Expand Down