diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py
index 60fd909881b..2d425a75151 100644
--- a/invokeai/app/services/shared/graph.py
+++ b/invokeai/app/services/shared/graph.py
@@ -51,15 +51,18 @@ class Edge(BaseModel):
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
+ def __str__(self):
+ return f"{self.source.node_id}.{self.source.field} -> {self.destination.node_id}.{self.destination.field}"
-def get_output_field(node: BaseInvocation, field: str) -> Any:
+
+def get_output_field_type(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field
-def get_input_field(node: BaseInvocation, field: str) -> Any:
+def get_input_field_type(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
@@ -93,6 +96,10 @@ def is_list_or_contains_list(t):
return False
+def is_any(t: Any) -> bool:
+ return t == Any or Any in get_args(t)
+
+
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type:
return False
@@ -102,13 +109,7 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
# TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
if from_type and to_type:
# Ports are compatible
- if (
- from_type == to_type
- or from_type == Any
- or to_type == Any
- or Any in get_args(from_type)
- or Any in get_args(to_type)
- ):
+ if from_type == to_type or is_any(from_type) or is_any(to_type):
return True
if from_type in get_args(to_type):
@@ -140,10 +141,10 @@ def are_connections_compatible(
"""Determines if a connection between fields of two nodes is compatible."""
# TODO: handle iterators and collectors
- from_node_field = get_output_field(from_node, from_field)
- to_node_field = get_input_field(to_node, to_field)
+ from_type = get_output_field_type(from_node, from_field)
+ to_type = get_input_field_type(to_node, to_field)
- return are_connection_types_compatible(from_node_field, to_node_field)
+ return are_connection_types_compatible(from_type, to_type)
T = TypeVar("T")
@@ -440,17 +441,19 @@ def validate_self(self) -> None:
self.get_node(edge.destination.node_id),
edge.destination.field,
):
- raise InvalidEdgeError(
- f"Invalid edge from {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
- )
+ raise InvalidEdgeError(f"Edge source and target types do not match ({edge})")
# Validate all iterators & collectors
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
for node in self.nodes.values():
- if isinstance(node, IterateInvocation) and not self._is_iterator_connection_valid(node.id):
- raise InvalidEdgeError(f"Invalid iterator node {node.id}")
- if isinstance(node, CollectInvocation) and not self._is_collector_connection_valid(node.id):
- raise InvalidEdgeError(f"Invalid collector node {node.id}")
+ if isinstance(node, IterateInvocation):
+ err = self._is_iterator_connection_valid(node.id)
+ if err is not None:
+ raise InvalidEdgeError(f"Invalid iterator node ({node.id}): {err}")
+ if isinstance(node, CollectInvocation):
+ err = self._is_collector_connection_valid(node.id)
+ if err is not None:
+ raise InvalidEdgeError(f"Invalid collector node ({node.id}): {err}")
return None
@@ -477,11 +480,11 @@ def is_valid(self) -> bool:
def _is_destination_field_Any(self, edge: Edge) -> bool:
"""Checks if the destination field for an edge is of type typing.Any"""
- return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == Any
+ return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == Any
def _is_destination_field_list_of_Any(self, edge: Edge) -> bool:
"""Checks if the destination field for an edge is of type typing.Any"""
- return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]
+ return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]
def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
@@ -491,55 +494,40 @@ def _validate_edge(self, edge: Edge):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
- raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
+ raise InvalidEdgeError(f"One or both nodes don't exist ({edge})")
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
- raise InvalidEdgeError(
- f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
- )
+ raise InvalidEdgeError(f"Edge already exists ({edge})")
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
- raise InvalidEdgeError(
- f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
- )
+ raise InvalidEdgeError(f"Edge creates a cycle in the graph ({edge})")
# Validate that the field types are compatible
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
- raise InvalidEdgeError(
- f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
- )
+ raise InvalidEdgeError(f"Field types are incompatible ({edge})")
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
- if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
- raise InvalidEdgeError(
- f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
- )
+ err = self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source)
+ if err is not None:
+ raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}")
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
- if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
- raise InvalidEdgeError(
- f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
- )
+ err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination)
+ if err is not None:
+ raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
- if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
- raise InvalidEdgeError(
- f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
- )
-
- # Validate that we are not connecting collector to iterator (currently unsupported)
- if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
- raise InvalidEdgeError(
- f"Cannot connect collector to iterator: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
- )
+ err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
+ if err is not None:
+ raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
if (
@@ -548,10 +536,9 @@ def _validate_edge(self, edge: Edge):
and not self._is_destination_field_list_of_Any(edge)
and not self._is_destination_field_Any(edge)
):
- if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
- raise InvalidEdgeError(
- f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
- )
+ err = self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination)
+ if err is not None:
+ raise InvalidEdgeError(f"Collector input type does not match collector output type ({edge}): {err}")
def has_node(self, node_id: str) -> bool:
"""Determines whether or not a node exists in the graph."""
@@ -634,7 +621,7 @@ def _is_iterator_connection_valid(
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
- ) -> bool:
+ ) -> str | None:
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
@@ -645,29 +632,47 @@ def _is_iterator_connection_valid(
# Only one input is allowed for iterators
if len(inputs) > 1:
- return False
+ return "Iterator may only have one input edge"
+
+ input_node = self.get_node(inputs[0].node_id)
# Get input and output fields (the fields linked to the iterator's input/output)
- input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
- output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
+ input_field_type = get_output_field_type(input_node, inputs[0].field)
+ output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
- if get_origin(input_field) is not list:
- return False
+ if get_origin(input_field_type) is not list:
+ return "Iterator input must be a collection"
# Validate that all outputs match the input type
- input_field_item_type = get_args(input_field)[0]
- if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
- return False
+ input_field_item_type = get_args(input_field_type)[0]
+ if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)):
+ return "Iterator outputs must connect to an input with a matching type"
+
+ # Collector input type must match all iterator output types
+ if isinstance(input_node, CollectInvocation):
+ # Traverse the graph to find the first collector input edge. Collectors validate that their collection
+ # inputs are all of the same type, so we can use the first input edge to determine the collector's type
+ first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0]
+ first_collector_input_type = get_output_field_type(
+ self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
+ )
+ resolved_collector_type = (
+ first_collector_input_type
+ if get_origin(first_collector_input_type) is None
+ else get_args(first_collector_input_type)
+ )
+ if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
+ return "Iterator collection type must match all iterator output types"
- return True
+ return None
def _is_collector_connection_valid(
self,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
- ) -> bool:
+ ) -> str | None:
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
@@ -677,38 +682,42 @@ def _is_collector_connection_valid(
outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output)
- input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
- output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
+ input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs]
+ output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
# Validate that all inputs are derived from or match a single type
input_field_types = {
- t
- for input_field in input_fields
- for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
- if t != NoneType
+ resolved_type
+ for input_field_type in input_field_types
+ for resolved_type in (
+ [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
+ )
+ if resolved_type != NoneType
} # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
- return False # There is more than one root type
+ return "Collector input collection items must be of a single type"
# Get the input root type
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Verify that all outputs are lists
- if not all(is_list_or_contains_list(f) for f in output_fields):
- return False
+ if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types):
+ return "Collector output must connect to a collection input"
# Verify that all outputs match the input type (are a base class or the same class)
if not all(
- is_union_subtype(input_root_type, get_args(f)[0]) or issubclass(input_root_type, get_args(f)[0])
- for f in output_fields
+ is_any(t)
+ or is_union_subtype(input_root_type, get_args(t)[0])
+ or issubclass(input_root_type, get_args(t)[0])
+ for t in output_field_types
):
- return False
+ return "Collector outputs must connect to a collection input with a matching type"
- return True
+ return None
def nx_graph(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the layout of this graph"""
diff --git a/invokeai/frontend/web/public/locales/de.json b/invokeai/frontend/web/public/locales/de.json
index 56c172a3879..e67f84fe2ed 100644
--- a/invokeai/frontend/web/public/locales/de.json
+++ b/invokeai/frontend/web/public/locales/de.json
@@ -127,7 +127,6 @@
"autoAssignBoardOnClick": "Board per Klick automatisch zuweisen",
"noImageSelected": "Kein Bild ausgewählt",
"starImage": "Bild markieren",
- "assets": "Ressourcen",
"unstarImage": "Markierung entfernen",
"image": "Bild",
"deleteSelection": "Lösche Auswahl",
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 5b7840444e7..0ce49fa6700 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -305,8 +305,8 @@
},
"gallery": {
"gallery": "Gallery",
- "alwaysShowImageSizeBadge": "Always Show Image Size Badge",
"assets": "Assets",
+ "alwaysShowImageSizeBadge": "Always Show Image Size Badge",
"assetsTab": "Files you’ve uploaded for use in your projects.",
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
"autoSwitchNewImages": "Auto-Switch to New Images",
@@ -1248,6 +1248,8 @@
"problemCopyingLayer": "Unable to Copy Layer",
"problemSavingLayer": "Unable to Save Layer",
"problemDownloadingImage": "Unable to Download Image",
+ "pasteSuccess": "Pasted to {{destination}}",
+ "pasteFailed": "Paste Failed",
"prunedQueue": "Pruned Queue",
"sentToCanvas": "Sent to Canvas",
"sentToUpscale": "Sent to Upscale",
@@ -1816,6 +1818,14 @@
"newControlLayer": "New $t(controlLayers.controlLayer)",
"newInpaintMask": "New $t(controlLayers.inpaintMask)",
"newRegionalGuidance": "New $t(controlLayers.regionalGuidance)",
+ "pasteTo": "Paste To",
+ "pasteToAssets": "Assets",
+ "pasteToAssetsDesc": "Paste to Assets",
+ "pasteToBbox": "Bbox",
+ "pasteToBboxDesc": "New Layer (in Bbox)",
+ "pasteToCanvas": "Canvas",
+ "pasteToCanvasDesc": "New Layer (in Canvas)",
+ "pastedTo": "Pasted to {{destination}}",
"transparency": "Transparency",
"enableTransparencyEffect": "Enable Transparency Effect",
"disableTransparencyEffect": "Disable Transparency Effect",
@@ -2235,11 +2245,12 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
- "Low-VRAM mode",
- "Dynamic memory management",
- "Faster model loading times",
- "Fewer memory errors",
- "Expanded workflow batch capabilities"
+ "Improved VRAM setting defaults",
+ "On-demand model cache clearing",
+ "Expanded FLUX LoRA compatibility",
+ "Canvas Adjust Image filter",
+ "Cancel all but current queue item",
+ "Copy from and paste to Canvas"
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
diff --git a/invokeai/frontend/web/public/locales/es.json b/invokeai/frontend/web/public/locales/es.json
index a42af953ce4..324d34442e3 100644
--- a/invokeai/frontend/web/public/locales/es.json
+++ b/invokeai/frontend/web/public/locales/es.json
@@ -109,7 +109,6 @@
"deleteImage_many": "Eliminar {{count}} Imágenes",
"deleteImage_other": "Eliminar {{count}} Imágenes",
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
- "assets": "Activos",
"autoAssignBoardOnClick": "Asignar automática tableros al hacer clic",
"gallery": "Galería",
"noImageSelected": "Sin imágenes seleccionadas",
diff --git a/invokeai/frontend/web/public/locales/fr.json b/invokeai/frontend/web/public/locales/fr.json
index 31fe41716e7..182dea7e6a5 100644
--- a/invokeai/frontend/web/public/locales/fr.json
+++ b/invokeai/frontend/web/public/locales/fr.json
@@ -114,7 +114,6 @@
"sortDirection": "Direction de tri",
"sideBySide": "Côte-à-Côte",
"hover": "Au passage de la souris",
- "assets": "Ressources",
"alwaysShowImageSizeBadge": "Toujours montrer le badge de taille de l'Image",
"gallery": "Galerie",
"bulkDownloadRequestFailed": "Problème lors de la préparation du téléchargement",
diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json
index 166be4c40ed..fb797fe4395 100644
--- a/invokeai/frontend/web/public/locales/it.json
+++ b/invokeai/frontend/web/public/locales/it.json
@@ -116,7 +116,6 @@
"deleteImage_many": "Elimina {{count}} immagini",
"deleteImage_other": "Elimina {{count}} immagini",
"deleteImagePermanent": "Le immagini eliminate non possono essere ripristinate.",
- "assets": "Risorse",
"autoAssignBoardOnClick": "Assegna automaticamente la bacheca al clic",
"featuresWillReset": "Se elimini questa immagine, quelle funzionalità verranno immediatamente ripristinate.",
"loading": "Caricamento in corso",
@@ -1132,7 +1131,11 @@
"generation": "Generazione",
"other": "Altro",
"gallery": "Galleria",
- "batchSize": "Dimensione del lotto"
+ "batchSize": "Dimensione del lotto",
+ "cancelAllExceptCurrentQueueItemAlertDialog2": "Vuoi davvero annullare tutti gli elementi in coda in sospeso?",
+ "confirm": "Conferma",
+ "cancelAllExceptCurrentQueueItemAlertDialog": "L'annullamento di tutti gli elementi della coda, eccetto quello corrente, interromperà gli elementi in sospeso ma consentirà il completamento di quello in corso.",
+ "cancelAllExceptCurrentTooltip": "Annulla tutto tranne l'elemento corrente"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@@ -1963,6 +1966,25 @@
"noise_type": "Tipo di rumore",
"label": "Aggiungi rumore",
"noise_amount": "Quantità"
+ },
+ "adjust_image": {
+ "description": "Regola il canale selezionato di un'immagine.",
+ "alpha": "Alfa (RGBA)",
+ "label": "Regola l'immagine",
+ "blue": "Blu (RGBA)",
+ "luminosity": "Luminosità (LAB)",
+ "channel": "Canale",
+ "value_setting": "Valore",
+ "scale_values": "Scala i valori",
+ "red": "Rosso (RGBA)",
+ "green": "Verde (RGBA)",
+ "cyan": "Ciano (CMYK)",
+ "magenta": "Magenta (CMYK)",
+ "yellow": "Giallo (CMYK)",
+ "black": "Nero (CMYK)",
+ "hue": "Tonalità (HSV)",
+ "saturation": "Saturazione (HSV)",
+ "value": "Valore (HSV)"
}
},
"controlLayers_withCount_hidden": "Livelli di controllo ({{count}} nascosti)",
diff --git a/invokeai/frontend/web/public/locales/ja.json b/invokeai/frontend/web/public/locales/ja.json
index 4ae9d69713b..b425a9178fe 100644
--- a/invokeai/frontend/web/public/locales/ja.json
+++ b/invokeai/frontend/web/public/locales/ja.json
@@ -106,7 +106,6 @@
"featuresWillReset": "この画像を削除すると、これらの機能は即座にリセットされます。",
"unstarImage": "スターを外す",
"loading": "ロード中",
- "assets": "アセット",
"currentlyInUse": "この画像は現在下記の機能を使用しています:",
"drop": "ドロップ",
"dropOrUpload": "$t(gallery.drop) またはアップロード",
diff --git a/invokeai/frontend/web/public/locales/ko.json b/invokeai/frontend/web/public/locales/ko.json
index a79e7286dfe..d3cd12c8959 100644
--- a/invokeai/frontend/web/public/locales/ko.json
+++ b/invokeai/frontend/web/public/locales/ko.json
@@ -68,7 +68,6 @@
"gallerySettings": "갤러리 설정",
"deleteSelection": "선택 항목 삭제",
"featuresWillReset": "이 이미지를 삭제하면 해당 기능이 즉시 재설정됩니다.",
- "assets": "자산",
"noImagesInGallery": "보여줄 이미지가 없음",
"autoSwitchNewImages": "새로운 이미지로 자동 전환",
"loading": "불러오는 중",
diff --git a/invokeai/frontend/web/public/locales/nl.json b/invokeai/frontend/web/public/locales/nl.json
index 263e2e725eb..bf1f8cee035 100644
--- a/invokeai/frontend/web/public/locales/nl.json
+++ b/invokeai/frontend/web/public/locales/nl.json
@@ -90,7 +90,6 @@
"deleteImage_one": "Verwijder afbeelding",
"deleteImage_other": "",
"deleteImagePermanent": "Verwijderde afbeeldingen kunnen niet worden hersteld.",
- "assets": "Eigen onderdelen",
"autoAssignBoardOnClick": "Ken automatisch bord toe bij klikken",
"featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.",
"loading": "Bezig met laden",
diff --git a/invokeai/frontend/web/public/locales/pl.json b/invokeai/frontend/web/public/locales/pl.json
index 17dac8b104f..375212c35e1 100644
--- a/invokeai/frontend/web/public/locales/pl.json
+++ b/invokeai/frontend/web/public/locales/pl.json
@@ -105,7 +105,6 @@
"assetsTab": "Pliki, które wrzuciłeś do użytku w twoich projektach.",
"currentlyInUse": "Ten obraz jest obecnie w użyciu przez następujące funkcje:",
"boardsSettings": "Ustawienia tablic",
- "assets": "Aktywy",
"autoAssignBoardOnClick": "Automatycznie przypisz tablicę po kliknięciu",
"copy": "Kopiuj"
},
diff --git a/invokeai/frontend/web/public/locales/ru.json b/invokeai/frontend/web/public/locales/ru.json
index 962db313dac..5df86501121 100644
--- a/invokeai/frontend/web/public/locales/ru.json
+++ b/invokeai/frontend/web/public/locales/ru.json
@@ -106,7 +106,6 @@
"deleteImage_one": "Удалить изображение",
"deleteImage_few": "Удалить {{count}} изображения",
"deleteImage_many": "Удалить {{count}} изображений",
- "assets": "Ресурсы",
"autoAssignBoardOnClick": "Авто-назначение доски по клику",
"deleteSelection": "Удалить выделенное",
"featuresWillReset": "Если вы удалите это изображение, эти функции будут немедленно сброшены.",
diff --git a/invokeai/frontend/web/public/locales/tr.json b/invokeai/frontend/web/public/locales/tr.json
index 22119932bf2..ac78d8442f3 100644
--- a/invokeai/frontend/web/public/locales/tr.json
+++ b/invokeai/frontend/web/public/locales/tr.json
@@ -195,7 +195,6 @@
},
"gallery": {
"deleteImagePermanent": "Silinen görseller geri getirilemez.",
- "assets": "Özkaynaklar",
"autoAssignBoardOnClick": "Tıklanan Panoya Otomatik Atama",
"loading": "Yükleniyor",
"starImage": "Yıldız Koy",
diff --git a/invokeai/frontend/web/public/locales/vi.json b/invokeai/frontend/web/public/locales/vi.json
index 20acaa2be67..b398a46e9ac 100644
--- a/invokeai/frontend/web/public/locales/vi.json
+++ b/invokeai/frontend/web/public/locales/vi.json
@@ -86,7 +86,6 @@
"bulkDownloadRequestedDesc": "Yêu cầu tải xuống đang được chuẩn bị. Vui lòng chờ trong giây lát.",
"starImage": "Gắn Sao Cho Ảnh",
"openViewer": "Mở Trình Xem",
- "assets": "Tài Nguyên",
"viewerImage": "Trình Xem Ảnh",
"sideBySide": "Cạnh Nhau",
"alwaysShowImageSizeBadge": "Luôn Hiển Thị Kích Thước Ảnh",
diff --git a/invokeai/frontend/web/public/locales/zh_CN.json b/invokeai/frontend/web/public/locales/zh_CN.json
index 3a28e82c870..d1cb520b7ea 100644
--- a/invokeai/frontend/web/public/locales/zh_CN.json
+++ b/invokeai/frontend/web/public/locales/zh_CN.json
@@ -107,7 +107,6 @@
"noImagesInGallery": "无图像可用于显示",
"deleteImage_other": "删除{{count}}张图片",
"deleteImagePermanent": "删除的图片无法被恢复。",
- "assets": "素材",
"autoAssignBoardOnClick": "点击后自动分配面板",
"featuresWillReset": "如果您删除该图像,这些功能会立即被重置。",
"loading": "加载中",
diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx
index 7bef8d59a6d..69a79f1d835 100644
--- a/invokeai/frontend/web/src/app/components/App.tsx
+++ b/invokeai/frontend/web/src/app/components/App.tsx
@@ -12,10 +12,12 @@ import { useFocusRegionWatcher } from 'common/hooks/focus';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
+import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
import {
NewCanvasSessionDialog,
NewGallerySessionDialog,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
+import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
@@ -112,6 +114,9 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
+
+
+
);
};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts
index e2fd33ecf30..484d1f6882c 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts
@@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
-import { truncate, upperFirst } from 'lodash-es';
+import { truncate } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { JsonObject } from 'type-fest';
@@ -52,15 +52,12 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
const result = zPydanticValidationError.safeParse(response);
if (result.success) {
result.data.data.detail.map((e) => {
+ const description = truncate(e.msg.replace(/^(Value|Index|Key) error, /i, ''), { length: 256 });
toast({
id: 'QUEUE_BATCH_FAILED',
- title: truncate(upperFirst(e.msg), { length: 128 }),
+ title: t('queue.batchFailedToQueue'),
status: 'error',
- description: truncate(
- `Path:
- ${e.loc.join('.')}`,
- { length: 128 }
- ),
+ description,
});
});
} else if (response.status !== 403) {
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx
index f66709ef59b..ca264fa389c 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx
@@ -2,8 +2,8 @@ import { Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-l
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasContextMenuItemsCropCanvasToBbox } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuItemsCropCanvasToBbox';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
+import { useCopyCanvasToClipboard } from 'features/controlLayers/hooks/copyHooks';
import {
- useCopyCanvasToClipboard,
useNewControlLayerFromBbox,
useNewGlobalReferenceImageFromBbox,
useNewRasterLayerFromBbox,
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasPasteModal.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasPasteModal.tsx
new file mode 100644
index 00000000000..4f64a7808e9
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasPasteModal.tsx
@@ -0,0 +1,150 @@
+import {
+ Button,
+ Flex,
+ Modal,
+ ModalBody,
+ ModalCloseButton,
+ ModalContent,
+ ModalFooter,
+ ModalHeader,
+ ModalOverlay,
+} from '@invoke-ai/ui-library';
+import { useStore } from '@nanostores/react';
+import { useAppStore } from 'app/store/nanostores/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
+import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
+import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
+import { createNewCanvasEntityFromImage } from 'features/imageActions/actions';
+import { toast } from 'features/toast/toast';
+import { atom } from 'nanostores';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { PiBoundingBoxBold, PiImageBold } from 'react-icons/pi';
+import { useUploadImageMutation } from 'services/api/endpoints/images';
+
+const $imageFile = atom(null);
+export const setFileToPaste = (file: File) => $imageFile.set(file);
+const clearFileToPaste = () => $imageFile.set(null);
+
+export const CanvasPasteModal = memo(() => {
+ useAssertSingleton('CanvasPasteModal');
+ const { dispatch, getState } = useAppStore();
+ const { t } = useTranslation();
+ const imageToPaste = useStore($imageFile);
+ const canvasManager = useCanvasManager();
+ const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
+ const [uploadImage, { isLoading }] = useUploadImageMutation({ fixedCacheKey: 'canvasPasteModal' });
+
+ const getPosition = useCallback(
+ (destination: 'canvas' | 'bbox') => {
+ const { x, y } = canvasManager.stateApi.getBbox().rect;
+ if (destination === 'bbox') {
+ return { x, y };
+ }
+ const rasterLayerAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
+ if (rasterLayerAdapters.length === 0) {
+ return { x, y };
+ }
+ {
+ const { x, y } = canvasManager.compositor.getRectOfAdapters(rasterLayerAdapters);
+ return { x, y };
+ }
+ },
+ [canvasManager.compositor, canvasManager.stateApi]
+ );
+
+ const handlePaste = useCallback(
+ async (file: File, destination: 'assets' | 'canvas' | 'bbox') => {
+ try {
+ const is_intermediate = destination !== 'assets';
+ const imageDTO = await uploadImage({
+ file,
+ is_intermediate,
+ image_category: 'user',
+ board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
+ }).unwrap();
+
+ if (destination !== 'assets') {
+ createNewCanvasEntityFromImage({
+ type: 'raster_layer',
+ imageDTO,
+ dispatch,
+ getState,
+ overrides: { position: getPosition(destination) },
+ });
+ }
+ } catch {
+ toast({
+ title: t('toast.pasteFailed'),
+ status: 'error',
+ });
+ } finally {
+ clearFileToPaste();
+ toast({
+ title: t('toast.pasteSuccess', {
+ destination:
+ destination === 'assets'
+ ? t('controlLayers.pasteToAssets')
+ : destination === 'bbox'
+ ? t('controlLayers.pasteToBbox')
+ : t('controlLayers.pasteToCanvas'),
+ }),
+ status: 'success',
+ });
+ }
+ },
+ [autoAddBoardId, dispatch, getPosition, getState, t, uploadImage]
+ );
+
+ const pasteToAssets = useCallback(() => {
+ if (!imageToPaste) {
+ return;
+ }
+ handlePaste(imageToPaste, 'assets');
+ }, [handlePaste, imageToPaste]);
+
+ const pasteToCanvas = useCallback(() => {
+ if (!imageToPaste) {
+ return;
+ }
+ handlePaste(imageToPaste, 'canvas');
+ }, [handlePaste, imageToPaste]);
+
+ const pasteToBbox = useCallback(() => {
+ if (!imageToPaste) {
+ return;
+ }
+ handlePaste(imageToPaste, 'bbox');
+ }, [handlePaste, imageToPaste]);
+
+ return (
+
+
+
+ {t('controlLayers.pasteTo')}
+
+
+
+ }>
+ {t('controlLayers.pasteToCanvasDesc')}
+
+ }>
+ {t('controlLayers.pasteToBboxDesc')}
+
+
+
+
+
+
+
+
+
+ );
+});
+
+CanvasPasteModal.displayName = 'CanvasPasteModal';
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard.tsx b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard.tsx
index c718e4dc9d3..78940e3140d 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard.tsx
@@ -1,8 +1,8 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
+import { useCopyLayerToClipboard } from 'features/controlLayers/hooks/copyHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
-import { useCopyLayerToClipboard } from 'features/controlLayers/hooks/useCopyLayerToClipboard';
import { useEntityIsEmpty } from 'features/controlLayers/hooks/useEntityIsEmpty';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useCopyLayerToClipboard.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/copyHooks.ts
similarity index 56%
rename from invokeai/frontend/web/src/features/controlLayers/hooks/useCopyLayerToClipboard.ts
rename to invokeai/frontend/web/src/features/controlLayers/hooks/copyHooks.ts
index aedb4637967..35552b5937a 100644
--- a/invokeai/frontend/web/src/features/controlLayers/hooks/useCopyLayerToClipboard.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/hooks/copyHooks.ts
@@ -1,5 +1,6 @@
import { logger } from 'app/logging/logger';
import { withResultAsync } from 'common/util/result';
+import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
@@ -7,6 +8,7 @@ import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers
import { canvasToBlob } from 'features/controlLayers/konva/util';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
+import { startCase } from 'lodash-es';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { serializeError } from 'serialize-error';
@@ -53,3 +55,39 @@ export const useCopyLayerToClipboard = () => {
return copyLayerToCipboard;
};
+
+export const useCopyCanvasToClipboard = (region: 'canvas' | 'bbox') => {
+ const { t } = useTranslation();
+ const canvasManager = useCanvasManager();
+ const copyCanvasToClipboard = useCallback(async () => {
+ const rect =
+ region === 'bbox'
+ ? canvasManager.stateApi.getBbox().rect
+ : canvasManager.compositor.getVisibleRectOfType('raster_layer');
+
+ if (rect.width === 0 || rect.height === 0) {
+ toast({
+ title: t('controlLayers.copyRegionError', { region: startCase(region) }),
+ description: t('controlLayers.regionIsEmpty'),
+ status: 'warning',
+ });
+ return;
+ }
+
+ const result = await withResultAsync(async () => {
+ const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
+ const canvasElement = canvasManager.compositor.getCompositeCanvas(rasterAdapters, rect);
+ const blob = await canvasToBlob(canvasElement);
+ copyBlobToClipboard(blob);
+ });
+
+ if (result.isOk()) {
+ toast({ title: t('controlLayers.regionCopiedToClipboard', { region: startCase(region) }) });
+ } else {
+ log.error({ error: serializeError(result.error) }, 'Failed to save canvas to gallery');
+ toast({ title: t('controlLayers.copyRegionError', { region: startCase(region) }), status: 'error' });
+ }
+ }, [canvasManager.compositor, canvasManager.stateApi, region, t]);
+
+ return copyCanvasToClipboard;
+};
diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts
index 26a7dbb5d42..e337c31f5ab 100644
--- a/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts
@@ -4,7 +4,7 @@ import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
-import { canvasToBlob, getPrefixedId } from 'features/controlLayers/konva/util';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
controlLayerAdded,
entityRasterized,
@@ -27,9 +27,7 @@ import type {
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import type { BoardId } from 'features/gallery/store/types';
-import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
-import { startCase } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { serializeError } from 'serialize-error';
@@ -152,42 +150,6 @@ export const useSaveBboxToGallery = () => {
return func;
};
-export const useCopyCanvasToClipboard = (region: 'canvas' | 'bbox') => {
- const { t } = useTranslation();
- const canvasManager = useCanvasManager();
- const copyCanvasToClipboard = useCallback(async () => {
- const rect =
- region === 'bbox'
- ? canvasManager.stateApi.getBbox().rect
- : canvasManager.compositor.getVisibleRectOfType('raster_layer');
-
- if (rect.width === 0 || rect.height === 0) {
- toast({
- title: t('controlLayers.copyRegionError', { region: startCase(region) }),
- description: t('controlLayers.regionIsEmpty'),
- status: 'warning',
- });
- return;
- }
-
- const result = await withResultAsync(async () => {
- const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
- const canvasElement = canvasManager.compositor.getCompositeCanvas(rasterAdapters, rect);
- const blob = await canvasToBlob(canvasElement);
- copyBlobToClipboard(blob);
- });
-
- if (result.isOk()) {
- toast({ title: t('controlLayers.regionCopiedToClipboard', { region: startCase(region) }) });
- } else {
- log.error({ error: serializeError(result.error) }, 'Failed to save canvas to gallery');
- toast({ title: t('controlLayers.copyRegionError', { region: startCase(region) }), status: 'error' });
- }
- }, [canvasManager.compositor, canvasManager.stateApi, region, t]);
-
- return copyCanvasToClipboard;
-};
-
export const useNewRegionalReferenceImageFromBbox = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
diff --git a/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx b/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx
index 139a6af8bf3..d922faa0965 100644
--- a/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx
+++ b/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx
@@ -4,13 +4,17 @@ import { containsFiles, getFiles } from '@atlaskit/pragmatic-drag-and-drop/exter
import { preventUnhandled } from '@atlaskit/pragmatic-drag-and-drop/prevent-unhandled';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Heading } from '@invoke-ai/ui-library';
+import { useStore } from '@nanostores/react';
import { getStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
+import { setFileToPaste } from 'features/controlLayers/components/CanvasPasteModal';
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
import type { DndTargetState } from 'features/dnd/types';
+import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
+import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { uploadImages } from 'services/api/endpoints/images';
@@ -71,6 +75,8 @@ export const FullscreenDropzone = memo(() => {
const ref = useRef(null);
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const [dndState, setDndState] = useState('idle');
+ const activeTab = useAppSelector(selectActiveTab);
+ const isImageViewerOpen = useStore($imageViewer);
const validateAndUploadFiles = useCallback(
(files: File[]) => {
@@ -92,6 +98,15 @@ export const FullscreenDropzone = memo(() => {
});
return;
}
+
+ // While on the canvas tab and when pasting a single image, canvas may want to create a new layer. Let it handle
+ // the paste event.
+ const [firstImageFile] = files;
+ if (!isImageViewerOpen && activeTab === 'canvas' && files.length === 1 && firstImageFile) {
+ setFileToPaste(firstImageFile);
+ return;
+ }
+
const autoAddBoardId = selectAutoAddBoardId(getState());
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
@@ -104,7 +119,7 @@ export const FullscreenDropzone = memo(() => {
uploadImages(uploadArgs);
},
- [maxImageUploadCount, t]
+ [activeTab, isImageViewerOpen, maxImageUploadCount, t]
);
const onPaste = useCallback(
diff --git a/invokeai/frontend/web/src/features/imageActions/actions.ts b/invokeai/frontend/web/src/features/imageActions/actions.ts
index d9822f4ed4d..dfb5fec2f2b 100644
--- a/invokeai/frontend/web/src/features/imageActions/actions.ts
+++ b/invokeai/frontend/web/src/features/imageActions/actions.ts
@@ -19,12 +19,12 @@ import { selectBboxModelBase, selectBboxRect } from 'features/controlLayers/stor
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
- CanvasEntityState,
CanvasEntityType,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
CanvasRenderableEntityIdentifier,
+ CanvasRenderableEntityState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
@@ -82,7 +82,7 @@ export const createNewCanvasEntityFromImage = (arg: {
type: CanvasEntityType | 'regional_guidance_with_reference_image';
dispatch: AppDispatch;
getState: () => RootState;
- overrides?: Partial>;
+ overrides?: Partial>;
}) => {
const { type, imageDTO, dispatch, getState, overrides: _overrides } = arg;
const state = getState();
diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts
index b5cdfd39ee1..755ff4ea382 100644
--- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts
@@ -63,14 +63,6 @@ describe(validateConnectionTypes.name, () => {
});
describe('special cases', () => {
- it('should reject a COLLECTION input to a COLLECTION input', () => {
- const r = validateConnectionTypes(
- { name: 'CollectionField', cardinality: 'COLLECTION', batch: false },
- { name: 'CollectionField', cardinality: 'COLLECTION', batch: false }
- );
- expect(r).toBe(false);
- });
-
it('should accept equal types', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false },
diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts
index 44f632490a4..835bf83af03 100644
--- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts
@@ -8,13 +8,6 @@ import { type FieldType, isCollection, isSingle, isSingleOrCollection } from 'fe
* @returns True if the connection is valid, false otherwise.
*/
export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => {
- // TODO: There's a bug with Collect -> Iterate nodes:
- // https://github.com/invoke-ai/InvokeAI/issues/3956
- // Once this is resolved, we can remove this check.
- if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
- return false;
- }
-
if (areTypesEqual(sourceType, targetType)) {
return true;
}
diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py
index a1e2271ea9d..b6775b71ae3 100644
--- a/invokeai/version/invokeai_version.py
+++ b/invokeai/version/invokeai_version.py
@@ -1 +1 @@
-__version__ = "5.6.0"
+__version__ = "5.6.1rc1"
diff --git a/tests/test_node_graph.py b/tests/test_node_graph.py
index b0358c08bad..0a4ce775384 100644
--- a/tests/test_node_graph.py
+++ b/tests/test_node_graph.py
@@ -9,7 +9,9 @@
invocation,
invocation_output,
)
+from invokeai.app.invocations.math import AddInvocation
from invokeai.app.invocations.primitives import (
+ ColorInvocation,
FloatCollectionInvocation,
FloatInvocation,
IntegerInvocation,
@@ -689,9 +691,6 @@ def test_any_accepts_any():
def test_iterate_accepts_collection():
- """We need to update the validation for Collect -> Iterate to traverse to the Iterate
- node's output and compare that against the item type of the Collect node's collection. Until
- then, Collect nodes may not output into Iterate nodes."""
g = Graph()
n1 = IntegerInvocation(id="1", value=1)
n2 = IntegerInvocation(id="2", value=2)
@@ -706,9 +705,36 @@ def test_iterate_accepts_collection():
e3 = create_edge(n3.id, "collection", n4.id, "collection")
g.add_edge(e1)
g.add_edge(e2)
- # Once we fix the validation logic as described, this should should not raise an error
- with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"):
- g.add_edge(e3)
+ g.add_edge(e3)
+
+
+def test_iterate_validates_collection_inputs_against_iterator_outputs():
+ g = Graph()
+ n1 = IntegerInvocation(id="1", value=1)
+ n2 = IntegerInvocation(id="2", value=2)
+ n3 = CollectInvocation(id="3")
+ n4 = IterateInvocation(id="4")
+ n5 = AddInvocation(id="5")
+ g.add_node(n1)
+ g.add_node(n2)
+ g.add_node(n3)
+ g.add_node(n4)
+ g.add_node(n5)
+ e1 = create_edge(n1.id, "value", n3.id, "item")
+ e2 = create_edge(n2.id, "value", n3.id, "item")
+ e3 = create_edge(n3.id, "collection", n4.id, "collection")
+ e4 = create_edge(n4.id, "item", n5.id, "a")
+ g.add_edge(e1)
+ g.add_edge(e2)
+ g.add_edge(e3)
+ # Not throwing on this line indicates the collector's input types validated successfully against the iterator's output types
+ g.add_edge(e4)
+ with pytest.raises(InvalidEdgeError, match="Iterator collection type must match all iterator output types"):
+ # Connect iterator to a node with a different type than the collector inputs which is not allowed
+ n6 = ColorInvocation(id="6")
+ g.add_node(n6)
+ e5 = create_edge(n4.id, "item", n6.id, "color")
+ g.add_edge(e5)
def test_graph_can_generate_schema():