diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ed8ebf5 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/PixelOE b/PixelOE index 9e80eec..94e09f3 160000 --- a/PixelOE +++ b/PixelOE @@ -1 +1 @@ -Subproject commit 9e80eec44d261ba12debce3cffb1bca9c1e3f0a4 +Subproject commit 94e09f345e8469a5361943b8c6aa2d03d1ad9984 diff --git a/nodes.py b/nodes.py index 579b657..136d690 100644 --- a/nodes.py +++ b/nodes.py @@ -1,8 +1,12 @@ import cv2 import numpy as np import torch +import torchvision -from .PixelOE.pixeloe.pixelize import pixelize +from .PixelOE.src.pixeloe.legacy.pixelize import pixelize +from .PixelOE.src.pixeloe.torch import env as pixeloe_env +from .PixelOE.src.pixeloe.torch.pixelize import pixelize as pixelize_torch +from .PixelOE.src.pixeloe.torch.utils import pre_resize class PixelOE: @@ -20,31 +24,31 @@ def INPUT_TYPES(cls): "min": 0, "max": 4096, "step": 1, - "display": "number" + "display": "number", }), "patch_size": ("INT", { "default": 6, "min": 0, "max": 4096, "step": 1, - "display": "number" + "display": "number", }), "pixel_size": ("INT", { "default": 0, "min": 0, "max": 4096, "step": 1, - "display": "number" + "display": "number", }), "thickness": ("INT", { "default": 1, "min": 0, "max": 4096, "step": 1, - "display": "number" + "display": "number", }), "color_matching": ("BOOLEAN", { - "default": False + "default": False, }), "contrast": ("FLOAT", { "default": 1.0, @@ -52,7 +56,7 @@ def INPUT_TYPES(cls): "max": 10.0, "step": 0.01, "round": 0.001, - "display": "number" + "display": "number", }), "saturation": ("FLOAT", { "default": 1.0, @@ -60,24 +64,24 @@ def INPUT_TYPES(cls): "max": 10.0, "step": 0.01, "round": 0.001, - "display": "number" + "display": "number", }), "colors": ("INT", { "default": 256, "min": 1, "max": 256, "step": 1, - "display": "number" + "display": "number", }), "color_quant_method": (["kmeans", "maxcover"],), "colors_with_weight": ("BOOLEAN", { - "default": False + "default": False, }), "no_upscale": ("BOOLEAN", { - "default": False + "default": False, }), "no_downscale": ("BOOLEAN", { - "default": False + "default": False, }), }, } @@ -111,10 +115,102 @@ def process(img, mode, target_size, patch_size, pixel_size, thickness, img_pix_t = torch.from_numpy(img_pix_t)[None,] return (img_pix_t,) +class PixelOETorch: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "img": ("IMAGE",), + "mode": (["contrast", "center", "k-centroid", "bicubic", "nearest"],), + "target_size": ("INT", { + "default": 256, + "min": 0, + "max": 4096, + "step": 1, + "display": "number", + }), + "patch_size": ("INT", { + "default": 6, + "min": 0, + "max": 4096, + "step": 1, + "display": "number", + }), + "pixel_size": ("INT", { + "default": 6, + "min": 0, + "max": 4096, + "step": 1, + "display": "number", + }), + "thickness": ("INT", { + "default": 3, + "min": 0, + "max": 4096, + "step": 1, + "display": "number", + }), + "do_color_match": ("BOOLEAN", { + "default": True, + }), + "do_quant": ("BOOLEAN", { + "default": False, + }), + "num_colors": ("INT", { + "default": 32, + "min": 1, + "max": 256, + "step": 1, + "display": "number", + }), + "quant_mode": (["kmeans", "weighted-kmeans", "repeat-kmeans"],), + "dither_mode": (["ordered", "error_diffusion", "no"],), + "torch_compile": ("BOOLEAN", { + "default": True, + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("image",) + FUNCTION = "process_torch" + + CATEGORY = "image/pixelize" + + @staticmethod + def process_torch(img, target_size, patch_size, pixel_size, thickness, mode, do_color_match, + do_quant, num_colors, quant_mode, dither_mode, torch_compile): + + if pixel_size == 0: + pixel_size = None + + pixeloe_env.TORCH_COMPILE = torch_compile + + img_t = img.squeeze().permute(2, 0, 1) + img_pil = torchvision.transforms.functional.to_pil_image(img_t) + img_t = pre_resize( + img_pil=img_pil, + target_size=target_size, + patch_size=patch_size, + ) + img_pix_t = pixelize_torch( + img_t=img_t, pixel_size=pixel_size, thickness=thickness, mode=mode, + do_color_match=do_color_match, do_quant=do_quant, num_colors=num_colors, + quant_mode=quant_mode, dither_mode=dither_mode, + ) + + img_pix_t = img_pix_t.permute(0, 2, 3, 1) + return (img_pix_t,) + NODE_CLASS_MAPPINGS = { "PixelOE": PixelOE, + "PixelOETorch": PixelOETorch, } NODE_DISPLAY_NAME_MAPPINGS = { "PixelOE": "PixelOE", + "PixelOETorch": "PixelOETorch", }