diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 60fd909881b..2d425a75151 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -51,15 +51,18 @@ class Edge(BaseModel): source: EdgeConnection = Field(description="The connection for the edge's from node and field") destination: EdgeConnection = Field(description="The connection for the edge's to node and field") + def __str__(self): + return f"{self.source.node_id}.{self.source.field} -> {self.destination.node_id}.{self.destination.field}" -def get_output_field(node: BaseInvocation, field: str) -> Any: + +def get_output_field_type(node: BaseInvocation, field: str) -> Any: node_type = type(node) node_outputs = get_type_hints(node_type.get_output_annotation()) node_output_field = node_outputs.get(field) or None return node_output_field -def get_input_field(node: BaseInvocation, field: str) -> Any: +def get_input_field_type(node: BaseInvocation, field: str) -> Any: node_type = type(node) node_inputs = get_type_hints(node_type) node_input_field = node_inputs.get(field) or None @@ -93,6 +96,10 @@ def is_list_or_contains_list(t): return False +def is_any(t: Any) -> bool: + return t == Any or Any in get_args(t) + + def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if not from_type: return False @@ -102,13 +109,7 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: # TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such) if from_type and to_type: # Ports are compatible - if ( - from_type == to_type - or from_type == Any - or to_type == Any - or Any in get_args(from_type) - or Any in get_args(to_type) - ): + if from_type == to_type or is_any(from_type) or is_any(to_type): return True if from_type in get_args(to_type): @@ -140,10 +141,10 @@ def are_connections_compatible( """Determines if a connection between fields of two nodes is compatible.""" # TODO: handle iterators and collectors - from_node_field = get_output_field(from_node, from_field) - to_node_field = get_input_field(to_node, to_field) + from_type = get_output_field_type(from_node, from_field) + to_type = get_input_field_type(to_node, to_field) - return are_connection_types_compatible(from_node_field, to_node_field) + return are_connection_types_compatible(from_type, to_type) T = TypeVar("T") @@ -440,17 +441,19 @@ def validate_self(self) -> None: self.get_node(edge.destination.node_id), edge.destination.field, ): - raise InvalidEdgeError( - f"Invalid edge from {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" - ) + raise InvalidEdgeError(f"Edge source and target types do not match ({edge})") # Validate all iterators & collectors # TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available for node in self.nodes.values(): - if isinstance(node, IterateInvocation) and not self._is_iterator_connection_valid(node.id): - raise InvalidEdgeError(f"Invalid iterator node {node.id}") - if isinstance(node, CollectInvocation) and not self._is_collector_connection_valid(node.id): - raise InvalidEdgeError(f"Invalid collector node {node.id}") + if isinstance(node, IterateInvocation): + err = self._is_iterator_connection_valid(node.id) + if err is not None: + raise InvalidEdgeError(f"Invalid iterator node ({node.id}): {err}") + if isinstance(node, CollectInvocation): + err = self._is_collector_connection_valid(node.id) + if err is not None: + raise InvalidEdgeError(f"Invalid collector node ({node.id}): {err}") return None @@ -477,11 +480,11 @@ def is_valid(self) -> bool: def _is_destination_field_Any(self, edge: Edge) -> bool: """Checks if the destination field for an edge is of type typing.Any""" - return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == Any + return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == Any def _is_destination_field_list_of_Any(self, edge: Edge) -> bool: """Checks if the destination field for an edge is of type typing.Any""" - return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any] + return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any] def _validate_edge(self, edge: Edge): """Validates that a new edge doesn't create a cycle in the graph""" @@ -491,55 +494,40 @@ def _validate_edge(self, edge: Edge): from_node = self.get_node(edge.source.node_id) to_node = self.get_node(edge.destination.node_id) except NodeNotFoundError: - raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}") + raise InvalidEdgeError(f"One or both nodes don't exist ({edge})") # Validate that an edge to this node+field doesn't already exist input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): - raise InvalidEdgeError( - f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists" - ) + raise InvalidEdgeError(f"Edge already exists ({edge})") # Validate that no cycles would be created g = self.nx_graph_flat() g.add_edge(edge.source.node_id, edge.destination.node_id) if not nx.is_directed_acyclic_graph(g): - raise InvalidEdgeError( - f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}" - ) + raise InvalidEdgeError(f"Edge creates a cycle in the graph ({edge})") # Validate that the field types are compatible if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field): - raise InvalidEdgeError( - f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" - ) + raise InvalidEdgeError(f"Field types are incompatible ({edge})") # Validate if iterator output type matches iterator input type (if this edge results in both being set) if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": - if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source): - raise InvalidEdgeError( - f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" - ) + err = self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source) + if err is not None: + raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}") # Validate if iterator input type matches output type (if this edge results in both being set) if isinstance(from_node, IterateInvocation) and edge.source.field == "item": - if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination): - raise InvalidEdgeError( - f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" - ) + err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination) + if err is not None: + raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}") # Validate if collector input type matches output type (if this edge results in both being set) if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": - if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source): - raise InvalidEdgeError( - f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" - ) - - # Validate that we are not connecting collector to iterator (currently unsupported) - if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation): - raise InvalidEdgeError( - f"Cannot connect collector to iterator: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" - ) + err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source) + if err is not None: + raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}") # Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any] if ( @@ -548,10 +536,9 @@ def _validate_edge(self, edge: Edge): and not self._is_destination_field_list_of_Any(edge) and not self._is_destination_field_Any(edge) ): - if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination): - raise InvalidEdgeError( - f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" - ) + err = self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination) + if err is not None: + raise InvalidEdgeError(f"Collector input type does not match collector output type ({edge}): {err}") def has_node(self, node_id: str) -> bool: """Determines whether or not a node exists in the graph.""" @@ -634,7 +621,7 @@ def _is_iterator_connection_valid( node_id: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, - ) -> bool: + ) -> str | None: inputs = [e.source for e in self._get_input_edges(node_id, "collection")] outputs = [e.destination for e in self._get_output_edges(node_id, "item")] @@ -645,29 +632,47 @@ def _is_iterator_connection_valid( # Only one input is allowed for iterators if len(inputs) > 1: - return False + return "Iterator may only have one input edge" + + input_node = self.get_node(inputs[0].node_id) # Get input and output fields (the fields linked to the iterator's input/output) - input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field) - output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs] + input_field_type = get_output_field_type(input_node, inputs[0].field) + output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs] # Input type must be a list - if get_origin(input_field) is not list: - return False + if get_origin(input_field_type) is not list: + return "Iterator input must be a collection" # Validate that all outputs match the input type - input_field_item_type = get_args(input_field)[0] - if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)): - return False + input_field_item_type = get_args(input_field_type)[0] + if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)): + return "Iterator outputs must connect to an input with a matching type" + + # Collector input type must match all iterator output types + if isinstance(input_node, CollectInvocation): + # Traverse the graph to find the first collector input edge. Collectors validate that their collection + # inputs are all of the same type, so we can use the first input edge to determine the collector's type + first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0] + first_collector_input_type = get_output_field_type( + self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field + ) + resolved_collector_type = ( + first_collector_input_type + if get_origin(first_collector_input_type) is None + else get_args(first_collector_input_type) + ) + if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)): + return "Iterator collection type must match all iterator output types" - return True + return None def _is_collector_connection_valid( self, node_id: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, - ) -> bool: + ) -> str | None: inputs = [e.source for e in self._get_input_edges(node_id, "item")] outputs = [e.destination for e in self._get_output_edges(node_id, "collection")] @@ -677,38 +682,42 @@ def _is_collector_connection_valid( outputs.append(new_output) # Get input and output fields (the fields linked to the iterator's input/output) - input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs] - output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs] + input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs] + output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs] # Validate that all inputs are derived from or match a single type input_field_types = { - t - for input_field in input_fields - for t in ([input_field] if get_origin(input_field) is None else get_args(input_field)) - if t != NoneType + resolved_type + for input_field_type in input_field_types + for resolved_type in ( + [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type) + ) + if resolved_type != NoneType } # Get unique types type_tree = nx.DiGraph() type_tree.add_nodes_from(input_field_types) type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])]) type_degrees = type_tree.in_degree(type_tree.nodes) if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore - return False # There is more than one root type + return "Collector input collection items must be of a single type" # Get the input root type input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore # Verify that all outputs are lists - if not all(is_list_or_contains_list(f) for f in output_fields): - return False + if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types): + return "Collector output must connect to a collection input" # Verify that all outputs match the input type (are a base class or the same class) if not all( - is_union_subtype(input_root_type, get_args(f)[0]) or issubclass(input_root_type, get_args(f)[0]) - for f in output_fields + is_any(t) + or is_union_subtype(input_root_type, get_args(t)[0]) + or issubclass(input_root_type, get_args(t)[0]) + for t in output_field_types ): - return False + return "Collector outputs must connect to a collection input with a matching type" - return True + return None def nx_graph(self) -> nx.DiGraph: """Returns a NetworkX DiGraph representing the layout of this graph""" diff --git a/invokeai/frontend/web/public/locales/de.json b/invokeai/frontend/web/public/locales/de.json index 56c172a3879..e67f84fe2ed 100644 --- a/invokeai/frontend/web/public/locales/de.json +++ b/invokeai/frontend/web/public/locales/de.json @@ -127,7 +127,6 @@ "autoAssignBoardOnClick": "Board per Klick automatisch zuweisen", "noImageSelected": "Kein Bild ausgewählt", "starImage": "Bild markieren", - "assets": "Ressourcen", "unstarImage": "Markierung entfernen", "image": "Bild", "deleteSelection": "Lösche Auswahl", diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 5b7840444e7..0ce49fa6700 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -305,8 +305,8 @@ }, "gallery": { "gallery": "Gallery", - "alwaysShowImageSizeBadge": "Always Show Image Size Badge", "assets": "Assets", + "alwaysShowImageSizeBadge": "Always Show Image Size Badge", "assetsTab": "Files you’ve uploaded for use in your projects.", "autoAssignBoardOnClick": "Auto-Assign Board on Click", "autoSwitchNewImages": "Auto-Switch to New Images", @@ -1248,6 +1248,8 @@ "problemCopyingLayer": "Unable to Copy Layer", "problemSavingLayer": "Unable to Save Layer", "problemDownloadingImage": "Unable to Download Image", + "pasteSuccess": "Pasted to {{destination}}", + "pasteFailed": "Paste Failed", "prunedQueue": "Pruned Queue", "sentToCanvas": "Sent to Canvas", "sentToUpscale": "Sent to Upscale", @@ -1816,6 +1818,14 @@ "newControlLayer": "New $t(controlLayers.controlLayer)", "newInpaintMask": "New $t(controlLayers.inpaintMask)", "newRegionalGuidance": "New $t(controlLayers.regionalGuidance)", + "pasteTo": "Paste To", + "pasteToAssets": "Assets", + "pasteToAssetsDesc": "Paste to Assets", + "pasteToBbox": "Bbox", + "pasteToBboxDesc": "New Layer (in Bbox)", + "pasteToCanvas": "Canvas", + "pasteToCanvasDesc": "New Layer (in Canvas)", + "pastedTo": "Pasted to {{destination}}", "transparency": "Transparency", "enableTransparencyEffect": "Enable Transparency Effect", "disableTransparencyEffect": "Disable Transparency Effect", @@ -2235,11 +2245,12 @@ "whatsNew": { "whatsNewInInvoke": "What's New in Invoke", "items": [ - "Low-VRAM mode", - "Dynamic memory management", - "Faster model loading times", - "Fewer memory errors", - "Expanded workflow batch capabilities" + "Improved VRAM setting defaults", + "On-demand model cache clearing", + "Expanded FLUX LoRA compatibility", + "Canvas Adjust Image filter", + "Cancel all but current queue item", + "Copy from and paste to Canvas" ], "readReleaseNotes": "Read Release Notes", "watchRecentReleaseVideos": "Watch Recent Release Videos", diff --git a/invokeai/frontend/web/public/locales/es.json b/invokeai/frontend/web/public/locales/es.json index a42af953ce4..324d34442e3 100644 --- a/invokeai/frontend/web/public/locales/es.json +++ b/invokeai/frontend/web/public/locales/es.json @@ -109,7 +109,6 @@ "deleteImage_many": "Eliminar {{count}} Imágenes", "deleteImage_other": "Eliminar {{count}} Imágenes", "deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.", - "assets": "Activos", "autoAssignBoardOnClick": "Asignar automática tableros al hacer clic", "gallery": "Galería", "noImageSelected": "Sin imágenes seleccionadas", diff --git a/invokeai/frontend/web/public/locales/fr.json b/invokeai/frontend/web/public/locales/fr.json index 31fe41716e7..182dea7e6a5 100644 --- a/invokeai/frontend/web/public/locales/fr.json +++ b/invokeai/frontend/web/public/locales/fr.json @@ -114,7 +114,6 @@ "sortDirection": "Direction de tri", "sideBySide": "Côte-à-Côte", "hover": "Au passage de la souris", - "assets": "Ressources", "alwaysShowImageSizeBadge": "Toujours montrer le badge de taille de l'Image", "gallery": "Galerie", "bulkDownloadRequestFailed": "Problème lors de la préparation du téléchargement", diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index 166be4c40ed..fb797fe4395 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -116,7 +116,6 @@ "deleteImage_many": "Elimina {{count}} immagini", "deleteImage_other": "Elimina {{count}} immagini", "deleteImagePermanent": "Le immagini eliminate non possono essere ripristinate.", - "assets": "Risorse", "autoAssignBoardOnClick": "Assegna automaticamente la bacheca al clic", "featuresWillReset": "Se elimini questa immagine, quelle funzionalità verranno immediatamente ripristinate.", "loading": "Caricamento in corso", @@ -1132,7 +1131,11 @@ "generation": "Generazione", "other": "Altro", "gallery": "Galleria", - "batchSize": "Dimensione del lotto" + "batchSize": "Dimensione del lotto", + "cancelAllExceptCurrentQueueItemAlertDialog2": "Vuoi davvero annullare tutti gli elementi in coda in sospeso?", + "confirm": "Conferma", + "cancelAllExceptCurrentQueueItemAlertDialog": "L'annullamento di tutti gli elementi della coda, eccetto quello corrente, interromperà gli elementi in sospeso ma consentirà il completamento di quello in corso.", + "cancelAllExceptCurrentTooltip": "Annulla tutto tranne l'elemento corrente" }, "models": { "noMatchingModels": "Nessun modello corrispondente", @@ -1963,6 +1966,25 @@ "noise_type": "Tipo di rumore", "label": "Aggiungi rumore", "noise_amount": "Quantità" + }, + "adjust_image": { + "description": "Regola il canale selezionato di un'immagine.", + "alpha": "Alfa (RGBA)", + "label": "Regola l'immagine", + "blue": "Blu (RGBA)", + "luminosity": "Luminosità (LAB)", + "channel": "Canale", + "value_setting": "Valore", + "scale_values": "Scala i valori", + "red": "Rosso (RGBA)", + "green": "Verde (RGBA)", + "cyan": "Ciano (CMYK)", + "magenta": "Magenta (CMYK)", + "yellow": "Giallo (CMYK)", + "black": "Nero (CMYK)", + "hue": "Tonalità (HSV)", + "saturation": "Saturazione (HSV)", + "value": "Valore (HSV)" } }, "controlLayers_withCount_hidden": "Livelli di controllo ({{count}} nascosti)", diff --git a/invokeai/frontend/web/public/locales/ja.json b/invokeai/frontend/web/public/locales/ja.json index 4ae9d69713b..b425a9178fe 100644 --- a/invokeai/frontend/web/public/locales/ja.json +++ b/invokeai/frontend/web/public/locales/ja.json @@ -106,7 +106,6 @@ "featuresWillReset": "この画像を削除すると、これらの機能は即座にリセットされます。", "unstarImage": "スターを外す", "loading": "ロード中", - "assets": "アセット", "currentlyInUse": "この画像は現在下記の機能を使用しています:", "drop": "ドロップ", "dropOrUpload": "$t(gallery.drop) またはアップロード", diff --git a/invokeai/frontend/web/public/locales/ko.json b/invokeai/frontend/web/public/locales/ko.json index a79e7286dfe..d3cd12c8959 100644 --- a/invokeai/frontend/web/public/locales/ko.json +++ b/invokeai/frontend/web/public/locales/ko.json @@ -68,7 +68,6 @@ "gallerySettings": "갤러리 설정", "deleteSelection": "선택 항목 삭제", "featuresWillReset": "이 이미지를 삭제하면 해당 기능이 즉시 재설정됩니다.", - "assets": "자산", "noImagesInGallery": "보여줄 이미지가 없음", "autoSwitchNewImages": "새로운 이미지로 자동 전환", "loading": "불러오는 중", diff --git a/invokeai/frontend/web/public/locales/nl.json b/invokeai/frontend/web/public/locales/nl.json index 263e2e725eb..bf1f8cee035 100644 --- a/invokeai/frontend/web/public/locales/nl.json +++ b/invokeai/frontend/web/public/locales/nl.json @@ -90,7 +90,6 @@ "deleteImage_one": "Verwijder afbeelding", "deleteImage_other": "", "deleteImagePermanent": "Verwijderde afbeeldingen kunnen niet worden hersteld.", - "assets": "Eigen onderdelen", "autoAssignBoardOnClick": "Ken automatisch bord toe bij klikken", "featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.", "loading": "Bezig met laden", diff --git a/invokeai/frontend/web/public/locales/pl.json b/invokeai/frontend/web/public/locales/pl.json index 17dac8b104f..375212c35e1 100644 --- a/invokeai/frontend/web/public/locales/pl.json +++ b/invokeai/frontend/web/public/locales/pl.json @@ -105,7 +105,6 @@ "assetsTab": "Pliki, które wrzuciłeś do użytku w twoich projektach.", "currentlyInUse": "Ten obraz jest obecnie w użyciu przez następujące funkcje:", "boardsSettings": "Ustawienia tablic", - "assets": "Aktywy", "autoAssignBoardOnClick": "Automatycznie przypisz tablicę po kliknięciu", "copy": "Kopiuj" }, diff --git a/invokeai/frontend/web/public/locales/ru.json b/invokeai/frontend/web/public/locales/ru.json index 962db313dac..5df86501121 100644 --- a/invokeai/frontend/web/public/locales/ru.json +++ b/invokeai/frontend/web/public/locales/ru.json @@ -106,7 +106,6 @@ "deleteImage_one": "Удалить изображение", "deleteImage_few": "Удалить {{count}} изображения", "deleteImage_many": "Удалить {{count}} изображений", - "assets": "Ресурсы", "autoAssignBoardOnClick": "Авто-назначение доски по клику", "deleteSelection": "Удалить выделенное", "featuresWillReset": "Если вы удалите это изображение, эти функции будут немедленно сброшены.", diff --git a/invokeai/frontend/web/public/locales/tr.json b/invokeai/frontend/web/public/locales/tr.json index 22119932bf2..ac78d8442f3 100644 --- a/invokeai/frontend/web/public/locales/tr.json +++ b/invokeai/frontend/web/public/locales/tr.json @@ -195,7 +195,6 @@ }, "gallery": { "deleteImagePermanent": "Silinen görseller geri getirilemez.", - "assets": "Özkaynaklar", "autoAssignBoardOnClick": "Tıklanan Panoya Otomatik Atama", "loading": "Yükleniyor", "starImage": "Yıldız Koy", diff --git a/invokeai/frontend/web/public/locales/vi.json b/invokeai/frontend/web/public/locales/vi.json index 20acaa2be67..b398a46e9ac 100644 --- a/invokeai/frontend/web/public/locales/vi.json +++ b/invokeai/frontend/web/public/locales/vi.json @@ -86,7 +86,6 @@ "bulkDownloadRequestedDesc": "Yêu cầu tải xuống đang được chuẩn bị. Vui lòng chờ trong giây lát.", "starImage": "Gắn Sao Cho Ảnh", "openViewer": "Mở Trình Xem", - "assets": "Tài Nguyên", "viewerImage": "Trình Xem Ảnh", "sideBySide": "Cạnh Nhau", "alwaysShowImageSizeBadge": "Luôn Hiển Thị Kích Thước Ảnh", diff --git a/invokeai/frontend/web/public/locales/zh_CN.json b/invokeai/frontend/web/public/locales/zh_CN.json index 3a28e82c870..d1cb520b7ea 100644 --- a/invokeai/frontend/web/public/locales/zh_CN.json +++ b/invokeai/frontend/web/public/locales/zh_CN.json @@ -107,7 +107,6 @@ "noImagesInGallery": "无图像可用于显示", "deleteImage_other": "删除{{count}}张图片", "deleteImagePermanent": "删除的图片无法被恢复。", - "assets": "素材", "autoAssignBoardOnClick": "点击后自动分配面板", "featuresWillReset": "如果您删除该图像,这些功能会立即被重置。", "loading": "加载中", diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 7bef8d59a6d..69a79f1d835 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -12,10 +12,12 @@ import { useFocusRegionWatcher } from 'common/hooks/focus'; import { useClearStorage } from 'common/hooks/useClearStorage'; import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys'; import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal'; +import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal'; import { NewCanvasSessionDialog, NewGallerySessionDialog, } from 'features/controlLayers/components/NewSessionConfirmationAlertDialog'; +import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal'; import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone'; import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal'; @@ -112,6 +114,9 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => { + + + ); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts index e2fd33ecf30..484d1f6882c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts @@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { zPydanticValidationError } from 'features/system/store/zodSchemas'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; -import { truncate, upperFirst } from 'lodash-es'; +import { truncate } from 'lodash-es'; import { serializeError } from 'serialize-error'; import { queueApi } from 'services/api/endpoints/queue'; import type { JsonObject } from 'type-fest'; @@ -52,15 +52,12 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) = const result = zPydanticValidationError.safeParse(response); if (result.success) { result.data.data.detail.map((e) => { + const description = truncate(e.msg.replace(/^(Value|Index|Key) error, /i, ''), { length: 256 }); toast({ id: 'QUEUE_BATCH_FAILED', - title: truncate(upperFirst(e.msg), { length: 128 }), + title: t('queue.batchFailedToQueue'), status: 'error', - description: truncate( - `Path: - ${e.loc.join('.')}`, - { length: 128 } - ), + description, }); }); } else if (response.status !== 403) { diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx index f66709ef59b..ca264fa389c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx @@ -2,8 +2,8 @@ import { Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-l import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu'; import { CanvasContextMenuItemsCropCanvasToBbox } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuItemsCropCanvasToBbox'; import { NewLayerIcon } from 'features/controlLayers/components/common/icons'; +import { useCopyCanvasToClipboard } from 'features/controlLayers/hooks/copyHooks'; import { - useCopyCanvasToClipboard, useNewControlLayerFromBbox, useNewGlobalReferenceImageFromBbox, useNewRasterLayerFromBbox, diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasPasteModal.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasPasteModal.tsx new file mode 100644 index 00000000000..4f64a7808e9 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasPasteModal.tsx @@ -0,0 +1,150 @@ +import { + Button, + Flex, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, +} from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { useAppStore } from 'app/store/nanostores/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; +import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; +import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; +import { createNewCanvasEntityFromImage } from 'features/imageActions/actions'; +import { toast } from 'features/toast/toast'; +import { atom } from 'nanostores'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiBoundingBoxBold, PiImageBold } from 'react-icons/pi'; +import { useUploadImageMutation } from 'services/api/endpoints/images'; + +const $imageFile = atom(null); +export const setFileToPaste = (file: File) => $imageFile.set(file); +const clearFileToPaste = () => $imageFile.set(null); + +export const CanvasPasteModal = memo(() => { + useAssertSingleton('CanvasPasteModal'); + const { dispatch, getState } = useAppStore(); + const { t } = useTranslation(); + const imageToPaste = useStore($imageFile); + const canvasManager = useCanvasManager(); + const autoAddBoardId = useAppSelector(selectAutoAddBoardId); + const [uploadImage, { isLoading }] = useUploadImageMutation({ fixedCacheKey: 'canvasPasteModal' }); + + const getPosition = useCallback( + (destination: 'canvas' | 'bbox') => { + const { x, y } = canvasManager.stateApi.getBbox().rect; + if (destination === 'bbox') { + return { x, y }; + } + const rasterLayerAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer'); + if (rasterLayerAdapters.length === 0) { + return { x, y }; + } + { + const { x, y } = canvasManager.compositor.getRectOfAdapters(rasterLayerAdapters); + return { x, y }; + } + }, + [canvasManager.compositor, canvasManager.stateApi] + ); + + const handlePaste = useCallback( + async (file: File, destination: 'assets' | 'canvas' | 'bbox') => { + try { + const is_intermediate = destination !== 'assets'; + const imageDTO = await uploadImage({ + file, + is_intermediate, + image_category: 'user', + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, + }).unwrap(); + + if (destination !== 'assets') { + createNewCanvasEntityFromImage({ + type: 'raster_layer', + imageDTO, + dispatch, + getState, + overrides: { position: getPosition(destination) }, + }); + } + } catch { + toast({ + title: t('toast.pasteFailed'), + status: 'error', + }); + } finally { + clearFileToPaste(); + toast({ + title: t('toast.pasteSuccess', { + destination: + destination === 'assets' + ? t('controlLayers.pasteToAssets') + : destination === 'bbox' + ? t('controlLayers.pasteToBbox') + : t('controlLayers.pasteToCanvas'), + }), + status: 'success', + }); + } + }, + [autoAddBoardId, dispatch, getPosition, getState, t, uploadImage] + ); + + const pasteToAssets = useCallback(() => { + if (!imageToPaste) { + return; + } + handlePaste(imageToPaste, 'assets'); + }, [handlePaste, imageToPaste]); + + const pasteToCanvas = useCallback(() => { + if (!imageToPaste) { + return; + } + handlePaste(imageToPaste, 'canvas'); + }, [handlePaste, imageToPaste]); + + const pasteToBbox = useCallback(() => { + if (!imageToPaste) { + return; + } + handlePaste(imageToPaste, 'bbox'); + }, [handlePaste, imageToPaste]); + + return ( + + + + {t('controlLayers.pasteTo')} + + + + + + + + + + + + + + ); +}); + +CanvasPasteModal.displayName = 'CanvasPasteModal'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard.tsx b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard.tsx index c718e4dc9d3..78940e3140d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard.tsx @@ -1,8 +1,8 @@ import { MenuItem } from '@invoke-ai/ui-library'; import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { useCopyLayerToClipboard } from 'features/controlLayers/hooks/copyHooks'; import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy'; -import { useCopyLayerToClipboard } from 'features/controlLayers/hooks/useCopyLayerToClipboard'; import { useEntityIsEmpty } from 'features/controlLayers/hooks/useEntityIsEmpty'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useCopyLayerToClipboard.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/copyHooks.ts similarity index 56% rename from invokeai/frontend/web/src/features/controlLayers/hooks/useCopyLayerToClipboard.ts rename to invokeai/frontend/web/src/features/controlLayers/hooks/copyHooks.ts index aedb4637967..35552b5937a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/hooks/useCopyLayerToClipboard.ts +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/copyHooks.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import { withResultAsync } from 'common/util/result'; +import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer'; import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask'; import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer'; @@ -7,6 +8,7 @@ import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers import { canvasToBlob } from 'features/controlLayers/konva/util'; import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard'; import { toast } from 'features/toast/toast'; +import { startCase } from 'lodash-es'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { serializeError } from 'serialize-error'; @@ -53,3 +55,39 @@ export const useCopyLayerToClipboard = () => { return copyLayerToCipboard; }; + +export const useCopyCanvasToClipboard = (region: 'canvas' | 'bbox') => { + const { t } = useTranslation(); + const canvasManager = useCanvasManager(); + const copyCanvasToClipboard = useCallback(async () => { + const rect = + region === 'bbox' + ? canvasManager.stateApi.getBbox().rect + : canvasManager.compositor.getVisibleRectOfType('raster_layer'); + + if (rect.width === 0 || rect.height === 0) { + toast({ + title: t('controlLayers.copyRegionError', { region: startCase(region) }), + description: t('controlLayers.regionIsEmpty'), + status: 'warning', + }); + return; + } + + const result = await withResultAsync(async () => { + const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer'); + const canvasElement = canvasManager.compositor.getCompositeCanvas(rasterAdapters, rect); + const blob = await canvasToBlob(canvasElement); + copyBlobToClipboard(blob); + }); + + if (result.isOk()) { + toast({ title: t('controlLayers.regionCopiedToClipboard', { region: startCase(region) }) }); + } else { + log.error({ error: serializeError(result.error) }, 'Failed to save canvas to gallery'); + toast({ title: t('controlLayers.copyRegionError', { region: startCase(region) }), status: 'error' }); + } + }, [canvasManager.compositor, canvasManager.stateApi, region, t]); + + return copyCanvasToClipboard; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts index 26a7dbb5d42..e337c31f5ab 100644 --- a/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts @@ -4,7 +4,7 @@ import { deepClone } from 'common/util/deepClone'; import { withResultAsync } from 'common/util/result'; import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks'; -import { canvasToBlob, getPrefixedId } from 'features/controlLayers/konva/util'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import { controlLayerAdded, entityRasterized, @@ -27,9 +27,7 @@ import type { import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util'; import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; import type { BoardId } from 'features/gallery/store/types'; -import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard'; import { toast } from 'features/toast/toast'; -import { startCase } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { serializeError } from 'serialize-error'; @@ -152,42 +150,6 @@ export const useSaveBboxToGallery = () => { return func; }; -export const useCopyCanvasToClipboard = (region: 'canvas' | 'bbox') => { - const { t } = useTranslation(); - const canvasManager = useCanvasManager(); - const copyCanvasToClipboard = useCallback(async () => { - const rect = - region === 'bbox' - ? canvasManager.stateApi.getBbox().rect - : canvasManager.compositor.getVisibleRectOfType('raster_layer'); - - if (rect.width === 0 || rect.height === 0) { - toast({ - title: t('controlLayers.copyRegionError', { region: startCase(region) }), - description: t('controlLayers.regionIsEmpty'), - status: 'warning', - }); - return; - } - - const result = await withResultAsync(async () => { - const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer'); - const canvasElement = canvasManager.compositor.getCompositeCanvas(rasterAdapters, rect); - const blob = await canvasToBlob(canvasElement); - copyBlobToClipboard(blob); - }); - - if (result.isOk()) { - toast({ title: t('controlLayers.regionCopiedToClipboard', { region: startCase(region) }) }); - } else { - log.error({ error: serializeError(result.error) }, 'Failed to save canvas to gallery'); - toast({ title: t('controlLayers.copyRegionError', { region: startCase(region) }), status: 'error' }); - } - }, [canvasManager.compositor, canvasManager.stateApi, region, t]); - - return copyCanvasToClipboard; -}; - export const useNewRegionalReferenceImageFromBbox = () => { const { t } = useTranslation(); const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx b/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx index 139a6af8bf3..d922faa0965 100644 --- a/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx +++ b/invokeai/frontend/web/src/features/dnd/FullscreenDropzone.tsx @@ -4,13 +4,17 @@ import { containsFiles, getFiles } from '@atlaskit/pragmatic-drag-and-drop/exter import { preventUnhandled } from '@atlaskit/pragmatic-drag-and-drop/prevent-unhandled'; import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Box, Flex, Heading } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { getStore } from 'app/store/nanostores/store'; import { useAppSelector } from 'app/store/storeHooks'; +import { setFileToPaste } from 'features/controlLayers/components/CanvasPasteModal'; import { DndDropOverlay } from 'features/dnd/DndDropOverlay'; import type { DndTargetState } from 'features/dnd/types'; +import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer'; import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; import { selectMaxImageUploadCount } from 'features/system/store/configSlice'; import { toast } from 'features/toast/toast'; +import { selectActiveTab } from 'features/ui/store/uiSelectors'; import { memo, useCallback, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { uploadImages } from 'services/api/endpoints/images'; @@ -71,6 +75,8 @@ export const FullscreenDropzone = memo(() => { const ref = useRef(null); const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount); const [dndState, setDndState] = useState('idle'); + const activeTab = useAppSelector(selectActiveTab); + const isImageViewerOpen = useStore($imageViewer); const validateAndUploadFiles = useCallback( (files: File[]) => { @@ -92,6 +98,15 @@ export const FullscreenDropzone = memo(() => { }); return; } + + // While on the canvas tab and when pasting a single image, canvas may want to create a new layer. Let it handle + // the paste event. + const [firstImageFile] = files; + if (!isImageViewerOpen && activeTab === 'canvas' && files.length === 1 && firstImageFile) { + setFileToPaste(firstImageFile); + return; + } + const autoAddBoardId = selectAutoAddBoardId(getState()); const uploadArgs: UploadImageArg[] = files.map((file, i) => ({ @@ -104,7 +119,7 @@ export const FullscreenDropzone = memo(() => { uploadImages(uploadArgs); }, - [maxImageUploadCount, t] + [activeTab, isImageViewerOpen, maxImageUploadCount, t] ); const onPaste = useCallback( diff --git a/invokeai/frontend/web/src/features/imageActions/actions.ts b/invokeai/frontend/web/src/features/imageActions/actions.ts index d9822f4ed4d..dfb5fec2f2b 100644 --- a/invokeai/frontend/web/src/features/imageActions/actions.ts +++ b/invokeai/frontend/web/src/features/imageActions/actions.ts @@ -19,12 +19,12 @@ import { selectBboxModelBase, selectBboxRect } from 'features/controlLayers/stor import type { CanvasControlLayerState, CanvasEntityIdentifier, - CanvasEntityState, CanvasEntityType, CanvasInpaintMaskState, CanvasRasterLayerState, CanvasRegionalGuidanceState, CanvasRenderableEntityIdentifier, + CanvasRenderableEntityState, } from 'features/controlLayers/store/types'; import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util'; import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions'; @@ -82,7 +82,7 @@ export const createNewCanvasEntityFromImage = (arg: { type: CanvasEntityType | 'regional_guidance_with_reference_image'; dispatch: AppDispatch; getState: () => RootState; - overrides?: Partial>; + overrides?: Partial>; }) => { const { type, imageDTO, dispatch, getState, overrides: _overrides } = arg; const state = getState(); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts index b5cdfd39ee1..755ff4ea382 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts @@ -63,14 +63,6 @@ describe(validateConnectionTypes.name, () => { }); describe('special cases', () => { - it('should reject a COLLECTION input to a COLLECTION input', () => { - const r = validateConnectionTypes( - { name: 'CollectionField', cardinality: 'COLLECTION', batch: false }, - { name: 'CollectionField', cardinality: 'COLLECTION', batch: false } - ); - expect(r).toBe(false); - }); - it('should accept equal types', () => { const r = validateConnectionTypes( { name: 'IntegerField', cardinality: 'SINGLE', batch: false }, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts index 44f632490a4..835bf83af03 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts @@ -8,13 +8,6 @@ import { type FieldType, isCollection, isSingle, isSingleOrCollection } from 'fe * @returns True if the connection is valid, false otherwise. */ export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => { - // TODO: There's a bug with Collect -> Iterate nodes: - // https://github.com/invoke-ai/InvokeAI/issues/3956 - // Once this is resolved, we can remove this check. - if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') { - return false; - } - if (areTypesEqual(sourceType, targetType)) { return true; } diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py index a1e2271ea9d..b6775b71ae3 100644 --- a/invokeai/version/invokeai_version.py +++ b/invokeai/version/invokeai_version.py @@ -1 +1 @@ -__version__ = "5.6.0" +__version__ = "5.6.1rc1" diff --git a/tests/test_node_graph.py b/tests/test_node_graph.py index b0358c08bad..0a4ce775384 100644 --- a/tests/test_node_graph.py +++ b/tests/test_node_graph.py @@ -9,7 +9,9 @@ invocation, invocation_output, ) +from invokeai.app.invocations.math import AddInvocation from invokeai.app.invocations.primitives import ( + ColorInvocation, FloatCollectionInvocation, FloatInvocation, IntegerInvocation, @@ -689,9 +691,6 @@ def test_any_accepts_any(): def test_iterate_accepts_collection(): - """We need to update the validation for Collect -> Iterate to traverse to the Iterate - node's output and compare that against the item type of the Collect node's collection. Until - then, Collect nodes may not output into Iterate nodes.""" g = Graph() n1 = IntegerInvocation(id="1", value=1) n2 = IntegerInvocation(id="2", value=2) @@ -706,9 +705,36 @@ def test_iterate_accepts_collection(): e3 = create_edge(n3.id, "collection", n4.id, "collection") g.add_edge(e1) g.add_edge(e2) - # Once we fix the validation logic as described, this should should not raise an error - with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"): - g.add_edge(e3) + g.add_edge(e3) + + +def test_iterate_validates_collection_inputs_against_iterator_outputs(): + g = Graph() + n1 = IntegerInvocation(id="1", value=1) + n2 = IntegerInvocation(id="2", value=2) + n3 = CollectInvocation(id="3") + n4 = IterateInvocation(id="4") + n5 = AddInvocation(id="5") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + g.add_node(n5) + e1 = create_edge(n1.id, "value", n3.id, "item") + e2 = create_edge(n2.id, "value", n3.id, "item") + e3 = create_edge(n3.id, "collection", n4.id, "collection") + e4 = create_edge(n4.id, "item", n5.id, "a") + g.add_edge(e1) + g.add_edge(e2) + g.add_edge(e3) + # Not throwing on this line indicates the collector's input types validated successfully against the iterator's output types + g.add_edge(e4) + with pytest.raises(InvalidEdgeError, match="Iterator collection type must match all iterator output types"): + # Connect iterator to a node with a different type than the collector inputs which is not allowed + n6 = ColorInvocation(id="6") + g.add_node(n6) + e5 = create_edge(n4.id, "item", n6.id, "color") + g.add_edge(e5) def test_graph_can_generate_schema():