-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
436 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
import os | ||
import hashlib | ||
import random | ||
import imghdr | ||
import numpy as np | ||
from PIL import Image, ImageOps, ImageSequence | ||
import torch | ||
from server import folder_paths # Ensure this import is correct | ||
|
||
# class PreviewImage: | ||
# def __init__(self): | ||
# self.type = "temp" | ||
# self.prefix_append = "_preview_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) | ||
|
||
# RETURN_TYPES = () | ||
# FUNCTION = "preview_images" | ||
# OUTPUT_NODE = True | ||
# CATEGORY = "image" | ||
|
||
# def preview_images(self, images): | ||
# results = list() | ||
# for idx in range(images.shape[0]): | ||
# image = images[idx] | ||
# img_array = 255. * image.cpu().numpy() | ||
# img = Image.fromarray(np.clip(img_array, 0, 255).astype(np.uint8)) | ||
|
||
# # Display the image | ||
# img.show() | ||
|
||
# filename = f"preview_{idx:05d}{self.prefix_append}.png" | ||
# results.append({ | ||
# "filename": filename, | ||
# "subfolder": "", | ||
# "type": self.type | ||
# }) | ||
|
||
# return { "ui": { "images": results } } | ||
|
||
class LoadImagePlus: | ||
def __init__(self): | ||
self.img_extensions = [".png", ".jpg", ".jpeg", ".bmp", ".webp"] | ||
# self.preview_image = PreviewImage() | ||
|
||
@classmethod | ||
def INPUT_TYPES(cls): | ||
input_dir = folder_paths.get_input_directory() # Ensure this method is correct | ||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] | ||
return { | ||
"required": { | ||
"image": (sorted(files), {"image_upload": True}), | ||
"use_random_image": ("BOOLEAN", {"default": False}), | ||
"random_folder": ("STRING", {"default": "."}), | ||
"n_images": ("INT", {"default": 1, "min": 1, "max": 100}), | ||
"seed": ("INT", {"default": 0, "min": 0, "max": 100000}), | ||
"sort": ("BOOLEAN", {"default": False}), | ||
"loop_sequence": ("BOOLEAN", {"default": False}), | ||
# "enable_preview": ("BOOLEAN", {"default": True}), | ||
} | ||
} | ||
|
||
CATEGORY = "🧔🏻♂️🇰 🇪 🇼 🇰 " | ||
RETURN_TYPES = ("IMAGE", "MASK") | ||
FUNCTION = "load_image" | ||
|
||
def load_image(self, image, use_random_image, random_folder, n_images, seed, sort, loop_sequence): | ||
if use_random_image: | ||
output_image, output_mask = self.load_random_image(random_folder, n_images, seed, sort, loop_sequence) | ||
else: | ||
output_image, output_mask = self.load_specific_image(image) | ||
|
||
# if enable_preview: | ||
# preview_result = self.preview_image.preview_images(output_image) | ||
# return (output_image, output_mask, preview_result) | ||
# else: | ||
return (output_image, output_mask) | ||
|
||
def load_specific_image(self, image): | ||
image_path = folder_paths.get_annotated_filepath(image) | ||
img = Image.open(image_path) | ||
|
||
output_images = [] | ||
output_masks = [] | ||
w, h = None, None | ||
|
||
excluded_formats = ['MPO'] | ||
|
||
for i in ImageSequence.Iterator(img): | ||
i = ImageOps.exif_transpose(i) | ||
|
||
if i.mode == 'I': | ||
i = i.point(lambda i: i * (1 / 255)) | ||
image = i.convert("RGB") | ||
|
||
if len(output_images) == 0: | ||
w = image.size[0] | ||
h = image.size[1] | ||
|
||
if image.size[0] != w or image.size[1] != h: | ||
continue | ||
|
||
image = np.array(image).astype(np.float32) / 255.0 | ||
image = torch.from_numpy(image)[None,] | ||
if 'A' in i.getbands(): | ||
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 | ||
mask = 1. - torch.from_numpy(mask) | ||
else: | ||
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") | ||
output_images.append(image) | ||
output_masks.append(mask.unsqueeze(0)) | ||
|
||
if len(output_images) > 1 and img.format not in excluded_formats: | ||
output_image = torch.cat(output_images, dim=0) | ||
output_mask = torch.cat(output_masks, dim=0) | ||
else: | ||
output_image = output_images[0] | ||
output_mask = output_masks[0] | ||
|
||
return (output_image, output_mask) | ||
|
||
def load_random_image(self, folder, n_images, seed, sort, loop_sequence): | ||
files = [os.path.join(folder, f) for f in os.listdir(folder)] | ||
files = [f for f in files if os.path.isfile(f)] | ||
files = [f for f in files if any([f.endswith(ext) for ext in self.img_extensions])] | ||
files = [f for f in files if imghdr.what(f)] | ||
|
||
random.seed(seed) | ||
random.shuffle(files) | ||
|
||
image_paths = files[:n_images] | ||
|
||
if sort: | ||
image_paths = sorted(image_paths) | ||
|
||
imgs = [Image.open(image_path) for image_path in image_paths] | ||
output_images = [] | ||
for img in imgs: | ||
img = ImageOps.exif_transpose(img) | ||
if img.mode == 'I': | ||
img = img.point(lambda i: i * (1 / 255)) | ||
image = img.convert("RGB") | ||
image = np.array(image).astype(np.float32) / 255.0 | ||
output_images.append(image) | ||
|
||
if loop_sequence: | ||
output_images.append(output_images[0]) | ||
|
||
if len(output_images) > 1: | ||
# output_images = self.get_uniformly_sized_crops(output_images, target_n_pixels=1024**2) | ||
output_images = [torch.from_numpy(output_image)[None,] for output_image in output_images] | ||
output_image = torch.cat(output_images, dim=0) | ||
else: | ||
output_image = torch.from_numpy(output_images[0])[None,] | ||
|
||
# Create a dummy mask | ||
mask = torch.zeros((output_image.shape[0], 64, 64), dtype=torch.float32, device="cpu") | ||
|
||
return (output_image, mask) | ||
|
||
# @staticmethod | ||
# def get_uniformly_sized_crops(images, target_n_pixels): | ||
# resized_images = [] | ||
# for img in images: | ||
# h, w, _ = img.shape | ||
# aspect_ratio = w / h | ||
# if aspect_ratio > 1: | ||
# new_w, new_h = 512, int(512 / aspect_ratio) | ||
# else: | ||
# new_w, new_h = int(512 * aspect_ratio), 512 | ||
# resized = Image.fromarray((img * 255).astype(np.uint8)).resize((new_w, new_h), Image.LANCZOS) | ||
# resized = np.array(resized).astype(np.float32) / 255.0 | ||
|
||
# # Crop to 512x512 | ||
# h, w, _ = resized.shape | ||
# top = (h - 512) // 2 | ||
# left = (w - 512) // 2 | ||
# cropped = resized[top:top+512, left:left+512, :] | ||
|
||
# resized_images.append(cropped) | ||
# return resized_images | ||
|
||
@classmethod | ||
def IS_CHANGED(cls, image, use_random_image, random_folder, n_images, seed, sort, loop_sequence): | ||
if use_random_image: | ||
return seed # Return seed to indicate change when using random images | ||
else: | ||
image_path = folder_paths.get_annotated_filepath(image) | ||
m = hashlib.sha256() | ||
with open(image_path, 'rb') as f: | ||
m.update(f.read()) | ||
return m.digest().hex() | ||
|
||
@classmethod | ||
def VALIDATE_INPUTS(cls, image, use_random_image, random_folder, n_images, seed, sort, loop_sequence): | ||
if not use_random_image: | ||
if not folder_paths.exists_annotated_filepath(image): | ||
return "Invalid image file: {}".format(image) | ||
else: | ||
if not os.path.isdir(random_folder): | ||
return "Invalid folder path: {}".format(random_folder) | ||
return True | ||
|
||
NODE_CLASS_MAPPINGS = { | ||
"LoadImagePlus": LoadImagePlus | ||
} | ||
|
||
NODE_DISPLAY_NAME_MAPPINGS = { | ||
"LoadImagePlus": "Load Image Plus" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.