diff --git a/examples/add_images_with_axis_labels.py b/examples/add_images_with_axis_labels.py new file mode 100644 index 00000000000..d18bd814cd0 --- /dev/null +++ b/examples/add_images_with_axis_labels.py @@ -0,0 +1,31 @@ +import numpy as np + +import napari + +viewer = napari.Viewer() +viewer.axes.visible = True + +print(f'{viewer.axis_labels=}') # -> () + +image = viewer.add_image( + np.ones((5, 3, 2)), + axis_labels=("time", "y", "x"), + colormap='red', +) +print(f'{viewer.axis_labels=}') # -> ("time", "y", "x") + +image = viewer.add_image( + np.ones((4, 3, 2)), + axis_labels=("z", "y", "x"), + colormap='green', +) +print(f'{viewer.axis_labels=}') # -> ("z", "time", "y", "x") + +image = viewer.add_image( + np.ones((6, 4, 3, 2)), + axis_labels=["freq", "z", "y", "x"], + colormap='blue', +) +print(f'{viewer.axis_labels=}') # -> ("freq", "z", "time", "y", "x") + +#napari.run() diff --git a/napari/components/layerlist.py b/napari/components/layerlist.py index ef777e4fa19..6e2ae36e7c5 100644 --- a/napari/components/layerlist.py +++ b/napari/components/layerlist.py @@ -1,8 +1,8 @@ import itertools import warnings from collections import namedtuple -from functools import cached_property -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union +from functools import cached_property, reduce +from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple, Union import numpy as np @@ -474,3 +474,8 @@ def save( return [] return save_layers(path, layers, plugin=plugin, _writer=_writer) + + @property + def axis_labels(self) -> Set[Optional[str]]: + layer_labels = (layer.axis_labels for layer in self.layers) + return reduce(set.union, layer_labels, set()) diff --git a/napari/components/viewer_model.py b/napari/components/viewer_model.py index d385eb27015..c95de18f276 100644 --- a/napari/components/viewer_model.py +++ b/napari/components/viewer_model.py @@ -129,8 +129,6 @@ class ViewerModel(KeymapProvider, MousemapProvider, EventedModel): Order in which dimensions are displayed where the last two or last three dimensions correspond to row x column or plane x row x column if ndisplay is 2 or 3. - axis_labels : list of str - Dimension names. Attributes ---------- @@ -140,6 +138,8 @@ class ViewerModel(KeymapProvider, MousemapProvider, EventedModel): List of contained layers. dims : Dimensions Contains axes, indices, dimensions and sliders. + axis_labels : tuple of strings + Dimension names. """ # Using allow_mutation=False means these attributes aren't settable and don't @@ -167,6 +167,8 @@ class ViewerModel(KeymapProvider, MousemapProvider, EventedModel): # different events systems mouse_over_canvas: bool = False + axis_labels: Tuple[str, ...] = () + # Need to use default factory because slicer is not copyable which # is required for default values. _layer_slicer: _LayerSlicer = PrivateAttr(default_factory=_LayerSlicer) @@ -549,6 +551,14 @@ def _on_add_layer(self, event): # Update dims and grid model self._on_layers_change() self._on_grid_change() + + # Prepend new axis labels from layer + new_labels = [] + for axis in layer.axis_labels: + if axis and axis not in self.axis_labels: + new_labels.append(axis) + self.axis_labels = tuple(new_labels) + self.axis_labels + # Slice current layer based on dims self._update_layers(layers=[layer]) @@ -613,6 +623,13 @@ def _on_remove_layer(self, event): """ layer = event.value + all_labels = self.layers.axis_labels + new_labels = list(self.axis_labels) + for axis in layer.axis_labels: + if axis not in all_labels: + new_labels.remove(axis) + self.axis_labels = new_labels + # Disconnect all connections from layer disconnect_events(layer.events, self) disconnect_events(layer.events, self.layers) @@ -671,6 +688,7 @@ def add_image( plane=None, experimental_clipping_planes=None, custom_interpolation_kernel_2d=None, + axis_labels: Optional[Sequence[str]] = None, ) -> Union[Image, List[Image]]: """Add an image layer to the layer list. @@ -848,6 +866,7 @@ def add_image( 'plane': plane, 'experimental_clipping_planes': experimental_clipping_planes, 'custom_interpolation_kernel_2d': custom_interpolation_kernel_2d, + 'axis_labels': axis_labels, } # these arguments are *already* iterables in the single-channel case. @@ -861,6 +880,7 @@ def add_image( 'metadata', 'experimental_clipping_planes', 'custom_interpolation_kernel_2d', + 'axis_labels', } if channel_axis is None: diff --git a/napari/layers/base/base.py b/napari/layers/base/base.py index cd3d36dc9d7..2a1ec51501e 100644 --- a/napari/layers/base/base.py +++ b/napari/layers/base/base.py @@ -7,7 +7,7 @@ from collections import defaultdict, namedtuple from contextlib import contextmanager from functools import cached_property -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union import magicgui as mgui import numpy as np @@ -115,6 +115,12 @@ class Layer(KeymapProvider, MousemapProvider, ABC): Whether the data is multiscale or not. Multiscale data is represented by a list of data objects and should go from largest to smallest. + axis_labels : optional sequence of strings + The names of the layer's axes. + If the value of this is none, then a tuple of decreasing negative + numbers is generated so that axis correspondence can be like numpy + broadcasting (i.e. allow prepending but not appending of singular axes) + to match array shapes. Attributes ---------- @@ -204,6 +210,8 @@ class Layer(KeymapProvider, MousemapProvider, ABC): depends on the current zoom level. source : Source source of the layer (such as a plugin or widget) + axis_labels : tuple of strings + The names of the layer's axes. Notes ----- @@ -253,6 +261,7 @@ def __init__( cache=True, # this should move to future "data source" object. experimental_clipping_planes=None, mode='pan_zoom', + axis_labels: Optional[Sequence[str]] = None, ) -> None: super().__init__() @@ -292,6 +301,10 @@ def __init__( self._ndim = ndim + self._axis_labels: Optional[ + Tuple[str, ...] + ] = self._coerce_axis_labels(axis_labels) + self._slice_input = _SliceInput( ndisplay=2, point=(0,) * ndim, @@ -407,6 +420,25 @@ def __init__( # until we figure out nested evented objects self._overlays.events.connect(self.events._overlays) + @property + def axis_labels(self) -> Tuple[str, ...]: + return self._axis_labels + + @axis_labels.setter + def axis_labels(self, axis_labels: Optional[Sequence[str]]) -> None: + self._axis_labels = self._coerce_axis_labels(axis_labels) + + def _coerce_axis_labels( + self, axis_labels: Optional[Sequence[str]] + ) -> Tuple[Optional[str], ...]: + if axis_labels is None: + return (None,) * self.ndim + if len(axis_labels) != self.ndim: + raise ValueError( + "The number of axis labels ({len(axis_labels)}) must match the number of dimensions ({self.ndim})." + ) + return tuple(axis_labels) + def __str__(self): """Return self.name.""" return self.name @@ -843,6 +875,7 @@ def _get_base_state(self): 'experimental_clipping_planes': [ plane.dict() for plane in self.experimental_clipping_planes ], + 'axis_labels': self.axis_labels, } return base_dict diff --git a/napari/layers/image/image.py b/napari/layers/image/image.py index c6c2ae36767..a29a1427dd4 100644 --- a/napari/layers/image/image.py +++ b/napari/layers/image/image.py @@ -5,7 +5,7 @@ import types import warnings from contextlib import nullcontext -from typing import TYPE_CHECKING, List, Sequence, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union import numpy as np from scipy import ndimage as ndi @@ -256,6 +256,7 @@ def __init__( plane=None, experimental_clipping_planes=None, custom_interpolation_kernel_2d=None, + axis_labels: Optional[Sequence[str]] = None, ) -> None: if name is None and data is not None: name = magic_name(data) @@ -308,6 +309,7 @@ def __init__( multiscale=multiscale, cache=cache, experimental_clipping_planes=experimental_clipping_planes, + axis_labels=axis_labels, ) self.events.add( diff --git a/napari/view_layers.py b/napari/view_layers.py index c3994b3948c..330887b3ec3 100644 --- a/napari/view_layers.py +++ b/napari/view_layers.py @@ -94,7 +94,11 @@ def _merge_layer_viewer_sigs_docs(func): # merge the signatures of Viewer and viewer.add_* func.__signature__ = _combine_signatures( - add_method, Viewer, return_annotation=Viewer, exclude=('self',) + # hack to get around duplicate param name + add_method, + Viewer, + return_annotation=Viewer, + exclude=('self', 'axis_labels'), ) # merge the __annotations__