-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Got tired of going outside of comfy to find where nodes were so made the text search node.
- Loading branch information
Showing
7 changed files
with
448 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,41 @@ | ||
# __init__.py | ||
from .text_append_node import NODE_CLASS_MAPPINGS as TEXT_APPEND_NODE_CLASS_MAPPINGS | ||
from .text_append_node import NODE_DISPLAY_NAME_MAPPINGS as TEXT_APPEND_NODE_DISPLAY_NAME_MAPPINGS | ||
from .vramdebugplus import NODE_CLASS_MAPPINGS as VRAM_DEBUG_PLUS_NODE_CLASS_MAPPINGS | ||
from .vramdebugplus import NODE_DISPLAY_NAME_MAPPINGS as VRAM_DEBUG_PLUS_NODE_DISPLAY_NAME_MAPPINGS | ||
from .tensordebugplus import NODE_CLASS_MAPPINGS as TENSOR_DEBUG_PLUS_NODE_CLASS_MAPPINGS | ||
from .tensordebugplus import NODE_DISPLAY_NAME_MAPPINGS as TENSOR_DEBUG_PLUS_NODE_DISPLAY_NAME_MAPPINGS | ||
from .animation_schedule_output import NODE_CLASS_MAPPINGS as ANIMATION_SCHEDULE_OUTPUT_NODE_CLASS_MAPPINGS | ||
from .animation_schedule_output import NODE_DISPLAY_NAME_MAPPINGS as ANIMATION_SCHEDULE_OUTPUT_NODE_DISPLAY_NAME_MAPPINGS | ||
#WIP from .clipinterrogator import NODE_CLASS_MAPPINGS as CLIP_INTERROGATOR_NODE_CLASS_MAPPINGS | ||
#WIP from .clipinterrogator import NODE_DISPLAY_NAME_MAPPINGS as CLIP_INTERROGATOR_NODE_DISPLAY_NAME_MAPPINGS | ||
#WIP from .image_batcher import NODE_CLASS_MAPPINGS as IMAGE_BATCHER_NODE_CLASS_MAPPINGS | ||
#WIP from .image_batcher import NODE_DISPLAY_NAME_MAPPINGS as IMAGE_BATCHER_NODE_DISPLAY_NAME_MAPPINGS | ||
from .text_search_node import NODE_CLASS_MAPPINGS as TEXT_SEARCH_NODE_CLASS_MAPPINGS | ||
from .text_search_node import NODE_DISPLAY_NAME_MAPPINGS as TEXT_SEARCH_NODE_DISPLAY_NAME_MAPPINGS | ||
#WIP from .custom_emoji_menu_item_node import NODE_CLASS_MAPPINGS as CUSTOM_NODE_CLASS_MAPPINGS | ||
#WIP from .custom_emoji_menu_item_node import NODE_DISPLAY_NAME_MAPPINGS as CUSTOM_NODE_DISPLAY_NAME_MAPPINGS | ||
|
||
NODE_CLASS_MAPPINGS = { | ||
**TEXT_APPEND_NODE_CLASS_MAPPINGS, | ||
**VRAM_DEBUG_PLUS_NODE_CLASS_MAPPINGS, | ||
**TENSOR_DEBUG_PLUS_NODE_CLASS_MAPPINGS, | ||
**ANIMATION_SCHEDULE_OUTPUT_NODE_CLASS_MAPPINGS, | ||
#WIP **CLIP_INTERROGATOR_NODE_CLASS_MAPPINGS, | ||
#WIP **IMAGE_BATCHER_NODE_CLASS_MAPPINGS, | ||
**TEXT_SEARCH_NODE_CLASS_MAPPINGS, | ||
#WIP **CUSTOM_NODE_CLASS_MAPPINGS # Add the new node class mappings | ||
} | ||
|
||
NODE_DISPLAY_NAME_MAPPINGS = { | ||
**TEXT_APPEND_NODE_DISPLAY_NAME_MAPPINGS, | ||
**VRAM_DEBUG_PLUS_NODE_DISPLAY_NAME_MAPPINGS, | ||
**TENSOR_DEBUG_PLUS_NODE_DISPLAY_NAME_MAPPINGS, | ||
**ANIMATION_SCHEDULE_OUTPUT_NODE_DISPLAY_NAME_MAPPINGS, | ||
#WIP **CLIP_INTERROGATOR_NODE_DISPLAY_NAME_MAPPINGS, | ||
#WIP **IMAGE_BATCHER_NODE_DISPLAY_NAME_MAPPINGS, | ||
**TEXT_SEARCH_NODE_DISPLAY_NAME_MAPPINGS, | ||
#WIP **CUSTOM_NODE_DISPLAY_NAME_MAPPINGS # Add the new node display name mappings | ||
} | ||
|
||
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] | ||
# __init__.py | ||
from .text_append_node import NODE_CLASS_MAPPINGS as TEXT_APPEND_NODE_CLASS_MAPPINGS | ||
from .text_append_node import NODE_DISPLAY_NAME_MAPPINGS as TEXT_APPEND_NODE_DISPLAY_NAME_MAPPINGS | ||
from .vramdebugplus import NODE_CLASS_MAPPINGS as VRAM_DEBUG_PLUS_NODE_CLASS_MAPPINGS | ||
from .vramdebugplus import NODE_DISPLAY_NAME_MAPPINGS as VRAM_DEBUG_PLUS_NODE_DISPLAY_NAME_MAPPINGS | ||
from .tensordebugplus import NODE_CLASS_MAPPINGS as TENSOR_DEBUG_PLUS_NODE_CLASS_MAPPINGS | ||
from .tensordebugplus import NODE_DISPLAY_NAME_MAPPINGS as TENSOR_DEBUG_PLUS_NODE_DISPLAY_NAME_MAPPINGS | ||
from .animation_schedule_output import NODE_CLASS_MAPPINGS as ANIMATION_SCHEDULE_OUTPUT_NODE_CLASS_MAPPINGS | ||
from .animation_schedule_output import NODE_DISPLAY_NAME_MAPPINGS as ANIMATION_SCHEDULE_OUTPUT_NODE_DISPLAY_NAME_MAPPINGS | ||
#WIP from .clipinterrogator import NODE_CLASS_MAPPINGS as CLIP_INTERROGATOR_NODE_CLASS_MAPPINGS | ||
#WIP from .clipinterrogator import NODE_DISPLAY_NAME_MAPPINGS as CLIP_INTERROGATOR_NODE_DISPLAY_NAME_MAPPINGS | ||
#WIP from .image_batcher import NODE_CLASS_MAPPINGS as IMAGE_BATCHER_NODE_CLASS_MAPPINGS | ||
#WIP from .image_batcher import NODE_DISPLAY_NAME_MAPPINGS as IMAGE_BATCHER_NODE_DISPLAY_NAME_MAPPINGS | ||
from .text_search_node import NODE_CLASS_MAPPINGS as TEXT_SEARCH_NODE_CLASS_MAPPINGS | ||
from .text_search_node import NODE_DISPLAY_NAME_MAPPINGS as TEXT_SEARCH_NODE_DISPLAY_NAME_MAPPINGS | ||
#WIP from .custom_emoji_menu_item_node import NODE_CLASS_MAPPINGS as CUSTOM_NODE_CLASS_MAPPINGS | ||
#WIP from .custom_emoji_menu_item_node import NODE_DISPLAY_NAME_MAPPINGS as CUSTOM_NODE_DISPLAY_NAME_MAPPINGS | ||
|
||
NODE_CLASS_MAPPINGS = { | ||
**TEXT_APPEND_NODE_CLASS_MAPPINGS, | ||
**VRAM_DEBUG_PLUS_NODE_CLASS_MAPPINGS, | ||
**TENSOR_DEBUG_PLUS_NODE_CLASS_MAPPINGS, | ||
**ANIMATION_SCHEDULE_OUTPUT_NODE_CLASS_MAPPINGS, | ||
#WIP **CLIP_INTERROGATOR_NODE_CLASS_MAPPINGS, | ||
#WIP **IMAGE_BATCHER_NODE_CLASS_MAPPINGS, | ||
**TEXT_SEARCH_NODE_CLASS_MAPPINGS, | ||
#WIP **CUSTOM_NODE_CLASS_MAPPINGS # Add the new node class mappings | ||
} | ||
|
||
NODE_DISPLAY_NAME_MAPPINGS = { | ||
**TEXT_APPEND_NODE_DISPLAY_NAME_MAPPINGS, | ||
**VRAM_DEBUG_PLUS_NODE_DISPLAY_NAME_MAPPINGS, | ||
**TENSOR_DEBUG_PLUS_NODE_DISPLAY_NAME_MAPPINGS, | ||
**ANIMATION_SCHEDULE_OUTPUT_NODE_DISPLAY_NAME_MAPPINGS, | ||
#WIP **CLIP_INTERROGATOR_NODE_DISPLAY_NAME_MAPPINGS, | ||
#WIP **IMAGE_BATCHER_NODE_DISPLAY_NAME_MAPPINGS, | ||
**TEXT_SEARCH_NODE_DISPLAY_NAME_MAPPINGS, | ||
#WIP **CUSTOM_NODE_DISPLAY_NAME_MAPPINGS # Add the new node display name mappings | ||
} | ||
|
||
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import torch | ||
import open_clip | ||
from clip_interrogator import Config, Interrogator | ||
from PIL import Image | ||
import os | ||
import numpy as np | ||
|
||
class CLIPInterrogatorNode: | ||
CATEGORY = "π§π»ββοΈπ° πͺ πΌ π° " | ||
|
||
def __init__(self): | ||
self.device = "cuda" if torch.cuda.is_available() else "cpu" | ||
self.interrogator = None | ||
self.current_model = None | ||
|
||
@classmethod | ||
def INPUT_TYPES(cls): | ||
return { | ||
"required": { | ||
"image": ("IMAGE",), | ||
"clip_model_name": (["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k", "ViT-bigG-14/laion2b_s39b", "OpenCLIP ViT-G/14"],), | ||
"mode": (["best", "fast", "classic", "negative"],), | ||
"save_text": ("BOOLEAN", {"default": False}), | ||
"keep_model_loaded": ("BOOLEAN", {"default": False}), | ||
"output_dir": ("STRING", {"default": "same as image"}), | ||
}, | ||
} | ||
|
||
RETURN_TYPES = ("STRING",) | ||
FUNCTION = "interrogate_image" | ||
RETURN_NAMES = ("prompt",) | ||
|
||
def load_interrogator(self, clip_model_name): | ||
if self.interrogator is None or clip_model_name != self.current_model: | ||
config = Config( | ||
clip_model_name=clip_model_name, | ||
device=self.device | ||
) | ||
self.interrogator = Interrogator(config) | ||
self.current_model = clip_model_name | ||
|
||
def unload_interrogator(self): | ||
if self.interrogator is not None: | ||
del self.interrogator | ||
self.interrogator = None | ||
self.current_model = None | ||
torch.cuda.empty_cache() | ||
|
||
def interrogate_image(self, image, clip_model_name, mode, save_text, keep_model_loaded, output_dir): | ||
if not self.validate_inputs(image, clip_model_name, mode, save_text, keep_model_loaded, output_dir): | ||
return ("Error: Invalid inputs",) | ||
|
||
self.load_interrogator(clip_model_name) | ||
|
||
try: | ||
# Convert the image to PIL Image using the new method | ||
pil_image = self.comfy_tensor_to_pil(image) | ||
|
||
# Use the interrogator's methods directly | ||
if mode == 'best': | ||
result = self.interrogator.interrogate(pil_image) | ||
elif mode == 'fast': | ||
result = self.interrogator.interrogate_fast(pil_image) | ||
elif mode == 'classic': | ||
result = self.interrogator.interrogate_classic(pil_image) | ||
elif mode == 'negative': | ||
result = self.interrogator.interrogate_negative(pil_image) | ||
else: | ||
raise ValueError(f"Unknown mode: {mode}") | ||
|
||
if save_text: | ||
self.save_text_file("image", result, output_dir, image) | ||
|
||
if not keep_model_loaded: | ||
self.unload_interrogator() | ||
|
||
return (result,) | ||
except Exception as e: | ||
print(f"Error in CLIP Interrogator: {str(e)}") | ||
return (f"Error: Unable to generate prompt. {str(e)}",) | ||
|
||
def comfy_tensor_to_pil(self, tensor): | ||
# Ensure the tensor is on CPU and detached from the computation graph | ||
tensor = tensor.cpu().detach() | ||
|
||
# Convert to numpy array | ||
image_np = tensor.numpy() | ||
|
||
# Squeeze out any singleton dimensions | ||
image_np = np.squeeze(image_np) | ||
|
||
# Ensure the image has 3 dimensions (H, W, C) | ||
if image_np.ndim == 2: | ||
image_np = np.expand_dims(image_np, axis=-1) | ||
|
||
# If the image is grayscale, convert to RGB | ||
if image_np.shape[-1] == 1: | ||
image_np = np.repeat(image_np, 3, axis=-1) | ||
|
||
# Ensure the values are in the range [0, 255] | ||
if image_np.max() <= 1.0: | ||
image_np = (image_np * 255).astype(np.uint8) | ||
else: | ||
image_np = image_np.astype(np.uint8) | ||
|
||
# Create PIL Image | ||
return Image.fromarray(image_np) | ||
|
||
def save_text_file(self, image_name, prompt, output_dir, image): | ||
if output_dir == "same as image" or not output_dir: | ||
# Assume the image is from a "Load Image" node, which provides metadata | ||
if hasattr(image, 'already_saved_as'): | ||
output_dir = os.path.dirname(image.already_saved_as) | ||
else: | ||
output_dir = os.getcwd() # Fallback to current working directory | ||
|
||
file_path = os.path.join(output_dir, f"{image_name}_prompt.txt") | ||
with open(file_path, 'w', encoding='utf-8') as f: | ||
f.write(prompt) | ||
|
||
def validate_inputs(self, image, clip_model_name, mode, save_text, keep_model_loaded, output_dir): | ||
if not isinstance(image, torch.Tensor): | ||
print(f"Invalid image input. Expected a torch.Tensor, got {type(image)}.") | ||
return False | ||
if clip_model_name not in ["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k", "ViT-bigG-14/laion2b_s39b", "OpenCLIP ViT-G/14"]: | ||
print("Invalid CLIP model name.") | ||
return False | ||
if mode not in ["best", "fast", "classic", "negative"]: | ||
print("Invalid interrogation mode.") | ||
return False | ||
if not isinstance(save_text, bool): | ||
print("Invalid save_text input. Expected a boolean.") | ||
return False | ||
if not isinstance(keep_model_loaded, bool): | ||
print("Invalid keep_model_loaded input. Expected a boolean.") | ||
return False | ||
if not isinstance(output_dir, str): | ||
print("Invalid output_dir input. Expected a string.") | ||
return False | ||
return True | ||
|
||
NODE_CLASS_MAPPINGS = { | ||
"CLIPInterrogator": CLIPInterrogatorNode | ||
} | ||
|
||
NODE_DISPLAY_NAME_MAPPINGS = { | ||
"CLIPInterrogator": "CLIP Interrogator" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import folder_paths | ||
from server import PromptServer | ||
import nodes | ||
|
||
class CustomNodeWithEmojiMenuItem: | ||
@classmethod | ||
def INPUT_TYPES(s): | ||
return {"required": {"image": ("IMAGE",)}} | ||
|
||
RETURN_TYPES = ("IMAGE",) | ||
FUNCTION = "process_image" | ||
CATEGORY = "π§π»ββοΈπ° πͺ πΌ π° " | ||
|
||
def process_image(self, image): | ||
# Your image processing logic here | ||
return (image,) | ||
|
||
@classmethod | ||
def add_custom_menu_item(cls): | ||
try: | ||
from comfy.graph import NodeIdNode | ||
except ImportError as e: | ||
print(f"ImportError: {e}") | ||
return | ||
|
||
def do_nothing(node_id): | ||
# This function does nothing | ||
pass | ||
|
||
# Check if the PromptServer has the method add_context_menu_option | ||
if hasattr(PromptServer.instance, 'add_context_menu_option'): | ||
PromptServer.instance.add_context_menu_option('π§π»ββοΈπ° πͺ πΌ π°', do_nothing) | ||
PromptServer.instance.add_context_menu_option('Hello World!', do_nothing) | ||
else: | ||
print("PromptServer does not have the method add_context_menu_option") | ||
|
||
# Register the custom menu items | ||
CustomNodeWithEmojiMenuItem.add_custom_menu_item() | ||
|
||
# Register the node | ||
NODE_CLASS_MAPPINGS = { | ||
"CustomNodeWithEmojiMenuItem": CustomNodeWithEmojiMenuItem | ||
} | ||
NODE_DISPLAY_NAME_MAPPINGS = { | ||
"CustomNodeWithEmojiMenuItem": "Custom Node With Emoji Menu Item" | ||
} |
Oops, something went wrong.