Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
412 changes: 329 additions & 83 deletions app/processors/frame_edits.py

Large diffs are not rendered by default.

129 changes: 129 additions & 0 deletions app/processors/models_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,117 @@ def _check_tensorrt_cache(self, model_name: str, onnx_path: str) -> bool:
print(f"[ERROR] Failed TensorRT cache check: {e}")
return False

def _clean_tensorrt_cache(self, onnx_path: str, trt_options: dict) -> None:
    """
    Cleans up potentially corrupted TensorRT cache files for a specific model.

    The cleanup is best-effort: missing, locked or unreadable files are
    reported with a ``[WARN]`` message and skipped, so this method does not
    raise because of a filesystem error.

    Steps, in order:
      1. Scan ``<model>_ctx.onnx`` for every embedded
         ``TensorrtExecutionProvider_*.engine`` file name.
      2. Delete the context file.
      3. Delete each referenced engine file, looked up both directly in the
         cache directory and in a same-named sub-directory.
      4. Delete model-specific ``.profile``/``.cache``/``.timing`` files, the
         generic ``timing.cache`` and ``TensorrtExecutionProvider_*``
         timing/profile files found in the cache directory.

    Args:
        onnx_path (str): The local path to the ONNX model.
        trt_options (dict): The TensorRT options containing the dynamic cache
            path under ``trt_engine_cache_path``. ``None`` or a missing key
            falls back to ``"tensorrt-engines"``.
    """
    import os
    import re

    def _remove_file(path, info_template, warn_template):
        # Delete one regular file; report the outcome, never raise.
        if not os.path.isfile(path):
            return
        try:
            os.remove(path)
        except OSError as e:
            print(warn_template.format(path=path, err=e))
        else:
            print(info_template.format(path=path))

    cache_dir = (trt_options or {}).get("trt_engine_cache_path", "tensorrt-engines")
    base_onnx_name = os.path.splitext(os.path.basename(onnx_path))[0]
    ctx_file_path = os.path.join(cache_dir, f"{base_onnx_name}_ctx.onnx")

    # 1. Read the context file to find the engine file(s) before deleting it.
    #    The name is restricted to printable ASCII without path separators so a
    #    match cannot span binary payload or escape the cache directory, and
    #    every occurrence is collected (a model can own several engines).
    engine_name_pattern = re.compile(
        rb"TensorrtExecutionProvider_[^\x00-\x1f\x7f-\xff/\\]*?\.engine"
    )
    engine_names = []
    if os.path.isfile(ctx_file_path):
        try:
            with open(ctx_file_path, "rb") as f:
                content = f.read()
        except OSError as e:
            print(
                f"[WARN] Could not read corrupted context file {ctx_file_path} to find engine name: {e}"
            )
        else:
            engine_names = [
                m.group(0).decode("ascii")
                for m in engine_name_pattern.finditer(content)
            ]

    # Failsafe: ORT pathing behavior varies, so each engine is looked up both
    # directly in the cache dir and in a sub-directory named like the cache dir.
    engine_subdirectory_name = os.path.basename(cache_dir)
    engine_file_paths_to_check = []
    for engine_name in dict.fromkeys(engine_names):  # de-duplicate, keep order
        engine_file_paths_to_check.append(os.path.join(cache_dir, engine_name))
        engine_file_paths_to_check.append(
            os.path.join(cache_dir, engine_subdirectory_name, engine_name)
        )

    # 2. Delete the context file.
    _remove_file(
        ctx_file_path,
        "[INFO] Deleted corrupted TensorRT context file: {path}",
        "[WARN] Failed to delete {path} (it might be locked or missing): {err}",
    )

    # 3. Delete the engine file(s) if we found them.
    for engine_path in engine_file_paths_to_check:
        _remove_file(
            engine_path,
            "[INFO] Deleted corrupted TensorRT engine file: {path}",
            "[WARN] Failed to delete engine file {path}: {err}",
        )

    # 4. Delete any associated timing cache, profile files, or general cache files.
    if not os.path.isdir(cache_dir):
        return
    try:
        file_names = os.listdir(cache_dir)
    except OSError as e:
        print(
            f"[WARN] Failed to clean profile/timing/cache files in {cache_dir}: {e}"
        )
        return

    for file_name in file_names:
        # Model-specific files (e.g., SomeModel.profile).
        is_model_specific = file_name.startswith(
            base_onnx_name
        ) and file_name.endswith((".profile", ".cache", ".timing"))

        # Exact generic names (like DFM's "timing.cache").
        is_generic_timing = file_name == "timing.cache"

        # ORT's global architecture-based timing caches,
        # e.g. TensorrtExecutionProvider_cache_sm120.timing
        is_ort_global_timing = file_name.startswith(
            "TensorrtExecutionProvider_"
        ) and file_name.endswith((".timing", ".profile"))

        if is_model_specific or is_generic_timing or is_ort_global_timing:
            _remove_file(
                os.path.join(cache_dir, file_name),
                "[INFO] Deleted TensorRT auxiliary/timing file: {path}",
                "[WARN] Failed to delete auxiliary file {path}: {err}",
            )

def load_model(self, model_name, session_options=None):
"""
Loads an AI model (ONNX) with thread safety.
Expand Down Expand Up @@ -680,6 +791,15 @@ def load_model(self, model_name, session_options=None):
)
probe_process.terminate()
probe_process.join()

# Clean up corrupted caches caused by the timeout before raising
print(
f"[INFO] Cleaning up corrupted TensorRT cache for {model_name} due to timeout..."
)
self._clean_tensorrt_cache(
onnx_path, model_trt_options
)

raise RuntimeError(
"TensorRT Engine build timed out."
)
Expand All @@ -700,6 +820,15 @@ def load_model(self, model_name, session_options=None):
print(
f"[WARN] Probe attempt {attempt + 1} failed with exit code {exitcode}."
)

# Wipe corrupted artifacts before attempting the next retry
print(
f"[INFO] Cleaning up potentially corrupted TensorRT cache for {model_name}..."
)
self._clean_tensorrt_cache(
onnx_path, model_trt_options
)

if attempt < max_retries - 1:
print("[INFO] Retrying in 2 seconds...")
time.sleep(2.0)
Expand Down
165 changes: 139 additions & 26 deletions app/processors/utils/faceutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,8 +2001,67 @@ def calculate_distance_ratio(
def calc_eye_close_ratio(
lmk: np.ndarray, target_eye_ratio: np.ndarray = None
) -> np.ndarray:
lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12)
righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36)
"""
Calculates the Eye Aspect Ratio (EAR) with strict projection safeguards.
Includes Profile Occlusion Detection and Symmetric Blink Harmonization
to completely eliminate "fisheyes" and "lazy eyes".

Args:
lmk: Array of shape (N, 203, 2) or (1, 203, 2) containing landmarks.
target_eye_ratio: Optional target ratio to concatenate.

Returns:
np.ndarray: The safely clamped and harmonized eye ratios.
"""
# 1. Calculate raw horizontal width of the eyes
raw_left_width = np.linalg.norm(lmk[:, 0] - lmk[:, 12], axis=1, keepdims=True)
raw_right_width = np.linalg.norm(lmk[:, 24] - lmk[:, 36], axis=1, keepdims=True)

# SAFEGUARD A: Profile Occlusion Detection (The Fisheye Fix)
# If one eye is significantly narrower horizontally than the other (< 55%),
# the face is turned. The hidden eye's 2D landmarks are unreliable.
left_occluded = raw_left_width < (raw_right_width * 0.55)
right_occluded = raw_right_width < (raw_left_width * 0.55)

# SAFEGUARD B: Clamp minimum width to prevent ZeroDivision on extreme squishing
min_eye_width = 4.0
left_eye_width = np.maximum(raw_left_width, min_eye_width)
right_eye_width = np.maximum(raw_right_width, min_eye_width)

# 2. Calculate vertical height of the eyes
left_eye_height = np.linalg.norm(lmk[:, 6] - lmk[:, 18], axis=1, keepdims=True)
right_eye_height = np.linalg.norm(lmk[:, 30] - lmk[:, 42], axis=1, keepdims=True)

# 3. Calculate Base Ratios
lefteye_close_ratio = left_eye_height / left_eye_width
righteye_close_ratio = right_eye_height / right_eye_width

# SAFEGUARD C: Apply Occlusion Lock
# Force the hidden eye to perfectly mirror the visible eye's EAR.
# This prevents the network from rendering a bulging wide-open eye.
lefteye_close_ratio = np.where(
left_occluded, righteye_close_ratio, lefteye_close_ratio
)
righteye_close_ratio = np.where(
right_occluded, lefteye_close_ratio, righteye_close_ratio
)

# SAFEGUARD D: Symmetric Blink Harmonization (Anti "Lazy-Eye")
blink_threshold = 0.28
is_blinking = (lefteye_close_ratio < blink_threshold) & (
righteye_close_ratio < blink_threshold
)

avg_ratio = (lefteye_close_ratio + righteye_close_ratio) / 2.0

lefteye_close_ratio = np.where(is_blinking, avg_ratio, lefteye_close_ratio)
righteye_close_ratio = np.where(is_blinking, avg_ratio, righteye_close_ratio)

# SAFEGUARD E: Hard clamp the final ratio to biologically plausible limits.
max_safe_ear = 0.45
lefteye_close_ratio = np.clip(lefteye_close_ratio, 0.0, max_safe_ear)
righteye_close_ratio = np.clip(righteye_close_ratio, 0.0, max_safe_ear)

if target_eye_ratio is not None:
return np.concatenate(
[lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1
Expand All @@ -2012,8 +2071,42 @@ def calc_eye_close_ratio(


# imported from https://github.com/KwaiVGI/LivePortrait/blob/main/src/utils/live_portrait_wrapper.py
# def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
# return calculate_distance_ratio(lmk, 90, 102, 48, 66)
def calc_lip_close_ratio(
    lmk: np.ndarray,
    min_mouth_width: float = 8.0,
    max_safe_mar: float = 0.85,
) -> np.ndarray:
    """
    Calculates the Mouth Aspect Ratio (MAR) with strict projection safeguards.

    The ratio is lip height divided by mouth width. The width is floored and
    the result is clipped so that a collapsed mouth width (profile faces,
    extreme pouting) cannot blow the ratio up or divide by zero.

    Args:
        lmk: Array of shape (N, 203, 2) or (1, 203, 2) containing landmarks.
        min_mouth_width: Lower bound applied to the mouth width (same units as
            the landmark coordinates) before dividing.
        max_safe_mar: Upper bound of the returned ratio; the lower bound is 0.0.

    Returns:
        np.ndarray: Array of shape (N, 1) with the clamped lip ratios.
    """
    # Horizontal width of the mouth (denominator): landmarks 48 and 66 are the
    # mouth corners in the 203-point layout.
    mouth_width = np.linalg.norm(lmk[:, 48] - lmk[:, 66], axis=1, keepdims=True)

    # Floor the width so a near-zero denominator cannot explode the ratio.
    mouth_width = np.maximum(mouth_width, min_mouth_width)

    # Vertical opening of the lips (numerator): landmarks 90 and 102 are the
    # upper and lower lip centers.
    lip_height = np.linalg.norm(lmk[:, 90] - lmk[:, 102], axis=1, keepdims=True)

    # Clip the final ratio to the configured plausible range.
    return np.clip(lip_height / mouth_width, 0.0, max_safe_mar)


# imported from https://github.com/KwaiVGI/LivePortrait/blob/main/src/utils/camera.py
Expand Down Expand Up @@ -2310,40 +2403,60 @@ def update_delta_new_mov_y(mov_y, delta_new, **kwargs):

# imported from https://github.com/KwaiVGI/LivePortrait/blob/main/src/utils/live_portrait_wrapper.py
def calc_combined_eye_ratio(c_d_eyes_i, source_lmk, device="cuda"):
    """
    Builds the combined eye-ratio tensor used for eye retargeting.

    The driving left and right eye ratios are averaged so that neither eye
    dominates, the mean is floored at 0.08, and the result is appended to the
    source eye ratios.

    Args:
        c_d_eyes_i: Driving eye ratios; ``c_d_eyes_i[0][0]`` is read as the
            left ratio and ``c_d_eyes_i[0][1]`` as the right one. If only one
            value is present it is used for both eyes.
        source_lmk: Source landmarks; a leading batch axis is added before
            they are passed to ``calc_eye_close_ratio``.
        device: Torch device the resulting tensors are moved to.

    Returns:
        torch.Tensor: ``[source eye ratios..., driving eye ratio]``
        concatenated along dim 1.
    """
    c_s_eyes = calc_eye_close_ratio(source_lmk[None])
    c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(device)

    # Read the driving ratios as plain floats so NumPy scalars, 0-d arrays and
    # tensors are all handled the same way.
    driving_row = c_d_eyes_i[0]
    left_eye_ratio = float(driving_row[0])
    right_eye_ratio = float(driving_row[1]) if len(driving_row) > 1 else left_eye_ratio

    # Mean of both eyes harmonizes the retargeting delta.
    mean_eye_ratio = (left_eye_ratio + right_eye_ratio) / 2.0

    # Minimum 0.08, otherwise the eyelids overlap.
    c_d_eyes_i_numpy = np.array([max(mean_eye_ratio, 0.08)], dtype=np.float32)
    c_d_eyes_i_tensor = torch.from_numpy(c_d_eyes_i_numpy).reshape(1, 1).to(device)

    # Format: [c_s,eyes, c_d,eyes,i]
    return torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)


def calc_independent_eye_ratios(
    c_d_eyes_i: np.ndarray, source_lmk: np.ndarray, device: str = "cuda"
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Calculates separate retargeting tensors for the left and right eyes.

    Each returned tensor is the source eye ratios followed by one driving eye
    ratio (left or right), so the two eyes can be retargeted independently
    (asymmetric blink / wink).

    Args:
        c_d_eyes_i: Driving eye ratios; ``c_d_eyes_i[0][0]`` is read as the
            left ratio and ``c_d_eyes_i[0][1]`` as the right one. If only one
            value is present it is used for both eyes.
        source_lmk: Source landmarks; a leading batch axis is added before
            they are passed to ``calc_eye_close_ratio``.
        device: Torch device the resulting tensors are moved to.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(ratio_left_target,
        ratio_right_target)``, each ``[source eye ratios..., target eye]``
        concatenated along dim 1.
    """
    c_s_eyes = calc_eye_close_ratio(source_lmk[None])
    c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(device)

    # Safely extract left and right eye ratios.
    driving_row = c_d_eyes_i[0]
    left_eye_ratio = float(driving_row[0])
    right_eye_ratio = float(driving_row[1]) if len(driving_row) > 1 else left_eye_ratio

    # Floor each ratio at 0.08, otherwise the eyelids overlap.
    left_val = np.array([max(left_eye_ratio, 0.08)], dtype=np.float32)
    right_val = np.array([max(right_eye_ratio, 0.08)], dtype=np.float32)

    left_tensor = torch.from_numpy(left_val).reshape(1, 1).to(device)
    right_tensor = torch.from_numpy(right_val).reshape(1, 1).to(device)

    # Format: [source eye ratios..., target_specific_eye]
    ratio_left_target = torch.cat([c_s_eyes_tensor, left_tensor], dim=1)
    ratio_right_target = torch.cat([c_s_eyes_tensor, right_tensor], dim=1)

    return ratio_left_target, ratio_right_target


# imported from https://github.com/KwaiVGI/LivePortrait/blob/main/src/utils/live_portrait_wrapper.py
Expand Down
Loading