Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions modules/python/src/custom_nodes/google_genmedia/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@

# Audio MIME type strings recognised by this package.
AUDIO_MIME_TYPES = [
    "audio/mp3",
    "audio/wav",
    "audio/mpeg",
]

# User-agent identifier for the generic Gemini custom node.
GEMINI_USER_AGENT = "cloud-solutions/comfyui-gemini-custom-node-v1"

# Aspect-ratio strings ("width:height") selectable for Gemini 2.5 Flash Image
# output. The order is kept stable because the list may be shown to users
# verbatim (e.g. in validation error messages).
GEMINI_25_FLASH_IMAGE_ASPECT_RATIO = [
    # Square
    "1:1",
    # Portrait / landscape pairs
    "2:3",
    "3:2",
    "3:4",
    "4:3",
    "4:5",
    "5:4",
    # Widescreen and ultra-wide
    "9:16",
    "16:9",
    "21:9",
]

# Upper bound on output tokens requested from Gemini 2.5 Flash Image.
GEMINI_25_FLASH_IMAGE_MAX_OUTPUT_TOKEN = 32_768
GEMINI_25_FLASH_IMAGE_USER_AGENT = (
"cloud-solutions/comfyui-gemini-25-flash-image-custom-node-v1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, project_id: Optional[str] = None, region: Optional[str] = Non
def generate_image(
self,
model: str,
aspect_ratio: str,
prompt: str,
temperature: float,
top_p: float,
Expand All @@ -66,12 +67,15 @@ def generate_image(
sexually_explicit_threshold: str,
dangerous_content_threshold: str,
system_instruction: str,
image: Optional[torch.Tensor] = None,
image1: torch.Tensor,
image2: Optional[torch.Tensor] = None,
image3: Optional[torch.Tensor] = None,
Comment thread
gushob21 marked this conversation as resolved.
) -> List[Image.Image]:
"""Generates an image using the Gemini Flash Image model.

Args:
model: The name of the Gemini model to use. default: gemini-2.5-flash-image-preview
aspect_ratio: The desired aspect ratio of the output image.
prompt: The text prompt for image generation.
temperature: Controls randomness in token generation.
top_p: The cumulative probability of tokens to consider for sampling.
Expand All @@ -82,8 +86,9 @@ def generate_image(
content.
dangerous_content_threshold: Safety threshold for dangerous content.
system_instruction: System-level instructions for the model.
image: An optional input image tensor for image-to-image tasks.
Defaults to None.
image1: The primary input image tensor for image-to-image tasks.
image2: An optional second input image tensor. Defaults to None.
image3: An optional third input image tensor. Defaults to None.

Returns:
A list of generated PIL images.
Expand All @@ -102,6 +107,9 @@ def generate_image(
top_k=top_k,
max_output_tokens=GEMINI_25_FLASH_IMAGE_MAX_OUTPUT_TOKEN,
response_modalities=["TEXT", "IMAGE"],
image_config=types.ImageConfig(
aspect_ratio=aspect_ratio,
),
system_instruction=system_instruction,
safety_settings=[
types.SafetySetting(
Expand Down Expand Up @@ -139,15 +147,15 @@ def generate_image(

contents = [types.Part.from_text(text=prompt)]

if image != None:
num_images = image.shape[0]
print(f"Number of Images {num_images}")
for i in range(num_images):
image_tensor = image[i].unsqueeze(0)
image_to_b64 = utils.tensor_to_pil_to_base64(image_tensor)
contents.append(
types.Part.from_bytes(data=image_to_b64, mime_type="image/png")
)
for i, image_tensor in enumerate([image1, image2, image3]):
if image_tensor is not None:
for j in range(image_tensor.shape[0]):
single_image = image_tensor[j].unsqueeze(0)
image_bytes = utils.tensor_to_pil_to_bytes(single_image)
contents.append(
types.Part.from_bytes(data=image_bytes, mime_type="image/png")
)
print(f"Appended image {i+1}, part {j+1} to contents.")

response = self.client.models.generate_content(
model=model, contents=contents, config=generate_content_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
import numpy as np
import torch

from .constants import GeminiFlashImageModel, ThresholdOptions
from .constants import (
GEMINI_25_FLASH_IMAGE_ASPECT_RATIO,
GeminiFlashImageModel,
ThresholdOptions,
)
from .custom_exceptions import APIExecutionError, APIInputError, ConfigurationError
from .gemini_flash_image_api import GeminiFlashImageAPI

Expand Down Expand Up @@ -56,6 +60,22 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
"default": "A vivid landscape painting of a futuristic city",
},
),
"image1": ("IMAGE",),
Comment thread
gushob21 marked this conversation as resolved.
"aspect_ratio": (
[
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"4:5",
"5:4",
"9:16",
"16:9",
"21:9",
Comment thread
gushob21 marked this conversation as resolved.
],
{"default": "16:9"},
),
"temperature": (
"FLOAT",
{"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01},
Expand All @@ -67,7 +87,8 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
"top_k": ("INT", {"default": 32, "min": 1, "max": 64}),
},
"optional": {
"image": ("IMAGE",),
"image2": ("IMAGE",),
"image3": ("IMAGE",),
# Safety Settings
"harassment_threshold": (
[threshold_option.name for threshold_option in ThresholdOptions],
Expand Down Expand Up @@ -119,16 +140,19 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
def generate_and_return_image(
self,
model: str,
aspect_ratio: str,
prompt: str,
temperature: float,
top_p: float,
top_k: int,
image1: torch.Tensor,
hate_speech_threshold: str,
harassment_threshold: str,
sexually_explicit_threshold: str,
dangerous_content_threshold: str,
system_instruction: str,
image: Optional[torch.Tensor] = None,
image2: Optional[torch.Tensor] = None,
image3: Optional[torch.Tensor] = None,
gcp_project_id: Optional[str] = None,
gcp_region: Optional[str] = None,
) -> Tuple[torch.Tensor,]:
Expand All @@ -140,6 +164,7 @@ def generate_and_return_image(

Args:
model: The Gemini Flash Image model to use. default: gemini-2.5-flash-image-preview
aspect_ratio: The desired aspect ratio of the output image.
prompt: The text prompt for image generation.
temperature: Controls randomness in token generation.
top_p: The cumulative probability of tokens to consider for sampling.
Expand All @@ -150,8 +175,9 @@ def generate_and_return_image(
content.
dangerous_content_threshold: Safety threshold for dangerous content.
system_instruction: System-level instructions for the model.
image: An optional input image tensor for image editing tasks.
Defaults to None.
image1: The primary input image tensor for image editing tasks.
image2: An optional second input image tensor. Defaults to None.
image3: An optional third input image tensor. Defaults to None.
gcp_project_id: The GCP project ID.
gcp_region: The GCP region.

Expand All @@ -170,23 +196,27 @@ def generate_and_return_image(
raise RuntimeError(
f"Gemini Flash Image API Configuration Error: {e}"
) from e

if image != None:
print(type(image))
if aspect_ratio not in GEMINI_25_FLASH_IMAGE_ASPECT_RATIO:
raise RuntimeError(
f"Invalid aspect ratio: {aspect_ratio}. Valid aspect ratios are: {GEMINI_25_FLASH_IMAGE_ASPECT_RATIO}."
)

try:
pil_images = gemini_flash_image_api.generate_image(
Comment thread
gushob21 marked this conversation as resolved.
model,
prompt,
temperature,
top_p,
top_k,
hate_speech_threshold,
harassment_threshold,
sexually_explicit_threshold,
dangerous_content_threshold,
system_instruction,
image,
model=model,
aspect_ratio=aspect_ratio,
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
hate_speech_threshold=hate_speech_threshold,
harassment_threshold=harassment_threshold,
sexually_explicit_threshold=sexually_explicit_threshold,
dangerous_content_threshold=dangerous_content_threshold,
system_instruction=system_instruction,
image1=image1,
image2=image2,
image3=image3,
)
except APIInputError as e:
raise RuntimeError(f"Image generation input error: {e}") from e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
google-api-core==2.24.2
google-cloud-aiplatform==1.111.0
google-cloud-storage==2.19.0
google-genai==1.32.0
google-genai==1.46.0
google-generativeai==0.8.5
moviepy==2.2.1
opencv-python-headless==4.11.0.86
49 changes: 31 additions & 18 deletions modules/python/src/custom_nodes/google_genmedia/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,34 @@ def validate_gcs_uri_and_image(
return False, f"An unexpected error occurred during GCS validation: {e}"


def tensor_to_pil_to_bytes(image: torch.Tensor, format: str = "PNG") -> bytes:
    """Encodes a PyTorch tensor or PIL Image as raw image-file bytes.

    If the input is a tensor, it is first converted to a PIL Image: a leading
    axis of size 1 is squeezed away, the values are scaled by 255, clamped to
    the ``[0, 255]`` range and cast to ``uint8``. Any other input is treated
    as an object that already exposes PIL's ``save(fp, format=...)`` method
    and is encoded unchanged. The image is written to an in-memory buffer and
    the buffer's content is returned.

    Args:
        image (torch.Tensor | PIL.Image.Image): The input image. If it's a
            PyTorch tensor, it is expected to have a shape like (1, H, W, C)
            and float values in the [0, 1] range; values outside that range
            are clamped rather than wrapped around.
        format: The image format name passed to ``PIL.Image.Image.save``.
            Defaults to "PNG".

    Returns:
        bytes: The raw bytes of the image, encoded in ``format``.
    """
    pil_image: PIL_Image.Image
    if isinstance(image, torch.Tensor):
        # detach() lets tensors that still track gradients be converted;
        # calling .numpy() directly on them raises a RuntimeError.
        scaled = image.squeeze(0).detach().cpu().numpy() * 255
        # Clamp before the uint8 cast: without it, values slightly outside
        # [0, 1] (e.g. 1.01 or -0.01) wrap around and produce speckled pixels.
        image_np = np.clip(scaled, 0, 255).astype(np.uint8)
        pil_image = PIL_Image.fromarray(image_np)
    else:
        pil_image = image

    buffered = io.BytesIO()
    pil_image.save(buffered, format=format)
    return buffered.getvalue()


def tensor_to_pil_to_base64(image: torch.tensor, format="PNG") -> bytes:
"""Converts a PyTorch tensor or PIL Image into PNG-encoded bytes.

Expand All @@ -990,21 +1018,6 @@ def tensor_to_pil_to_base64(image: torch.tensor, format="PNG") -> bytes:
"""

pil_image: PIL_Image.Image
image_input_bytes: bytes
try:
if isinstance(image, torch.Tensor):
image_np = (image.squeeze(0).cpu().numpy() * 255).astype(np.uint8)
pil_image = PIL_Image.fromarray(image_np)
print("Converted input image tensor to PIL Image for Base64 encoding.")
else:
pil_image = image
print(f"Using input image as is for Base64 (type: {type(image)}).")

buffered = io.BytesIO()
pil_image.save(buffered, format=format)
image_input_bytes = buffered.getvalue()
image_base64 = base64.b64encode(image_input_bytes).decode("utf-8")
return image_base64
except Exception as e:
print(f"Cant convert the image to base64 {e}")
print(f"Cant convert the image to base64 {e}")
image_input_bytes = tensor_to_pil_to_bytes(image, format)
image_base64 = base64.b64encode(image_input_bytes).decode("utf-8")
return image_base64
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading