diff --git a/src/plopp/backends/pythreejs/scatter3d.py b/src/plopp/backends/pythreejs/scatter3d.py index af758fc09..30398dbbe 100644 --- a/src/plopp/backends/pythreejs/scatter3d.py +++ b/src/plopp/backends/pythreejs/scatter3d.py @@ -2,7 +2,7 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) import uuid -from typing import Literal +from typing import Any, Literal import numpy as np import scipp as sc @@ -63,8 +63,6 @@ def __init__( opacity: float = 1, pixel_size: sc.Variable | float | None = None, ): - import pythreejs as p3 - check_ndim(data, ndim=1, origin='Scatter3d') self.uid = uid if uid is not None else uuid.uuid4().hex self._canvas = canvas @@ -73,6 +71,10 @@ def __init__( self._x = x self._y = y self._z = z + self._unique_color = to_rgb(f'C{artist_number}' if color is None else color) + self._opacity = opacity + self._new_points = None + self._new_colors = None # TODO: remove pixel_size in the next release self._size = size if pixel_size is None else pixel_size @@ -88,42 +90,64 @@ def __init__( dtype=float, unit=self._data.coords[x].unit ).value + self.points = self._make_point_cloud() + self._canvas.add(self.points) + if self._colormapper is not None: self._colormapper.add_artist(self.uid, self) - colors = self._colormapper.rgba(self.data)[..., :3].astype('float32') - else: - colors = np.broadcast_to( - np.array(to_rgb(f'C{artist_number}' if color is None else color)), - (self._data.coords[self._x].shape[0], 3), - ).astype('float32') - self.geometry = p3.BufferGeometry( + def _make_point_cloud(self) -> None: + """ + Create the point cloud geometry and material. + """ + import pythreejs as p3 + + self._backup_coords() + + geometry = p3.BufferGeometry( attributes={ 'position': p3.BufferAttribute( - array=np.array( + array=np.stack( [ self._data.coords[self._x].values.astype('float32'), self._data.coords[self._y].values.astype('float32'), self._data.coords[self._z].values.astype('float32'), - ] - ).T + ], + axis=1, + ) + ), + 'color': p3.BufferAttribute( + array=np.broadcast_to( + np.array(self._unique_color), + (self._data.coords[self._x].shape[0], 3), + ).astype('float32') ), - 'color': p3.BufferAttribute(array=colors), } ) + self._new_positions = None # TODO: a device pixel_ratio should probably be read from a config file pixel_ratio = 1.0 # Note that an additional factor of 2.5 (obtained from trial and error) seems to # be required to get the sizes right in the scene. - self.material = p3.PointsMaterial( + material = p3.PointsMaterial( vertexColors='VertexColors', size=2.5 * self._size * pixel_ratio, transparent=True, - opacity=opacity, + opacity=self._opacity, + depthTest=self._opacity > 0.5, ) - self.points = p3.Points(geometry=self.geometry, material=self.material) - self._canvas.add(self.points) + return p3.Points(geometry=geometry, material=material) + + def _backup_coords(self) -> None: + """ + Backup the current coordinates to be able to detect changes. + """ + self._old_coords = { + self._x: self._data.coords[self._x], + self._y: self._data.coords[self._y], + self._z: self._data.coords[self._z], + } def notify_artist(self, message: str) -> None: """ @@ -135,15 +159,27 @@ def notify_artist(self, message: str) -> None: message: The message from the colormapper. """ - self._update_colors() + self._new_colors = self._colormapper.rgba(self.data)[..., :3].astype('float32') + self._finalize_update() - def _update_colors(self): + def _update_positions(self) -> None: """ - Set the point cloud's rgba colors: + Update the point cloud's positions from the data. """ - self.geometry.attributes["color"].array = self._colormapper.rgba(self.data)[ - ..., :3 - ].astype('float32') + if all( + sc.identical(self._old_coords[dim], self._data.coords[dim]) + for dim in [self._x, self._y, self._z] + ): + return + self._backup_coords() + return np.stack( + [ + self._data.coords[self._x].values.astype('float32'), + self._data.coords[self._y].values.astype('float32'), + self._data.coords[self._z].values.astype('float32'), + ], + axis=1, + ) def update(self, new_values): """ @@ -155,19 +191,89 @@ def update(self, new_values): New data to update the point cloud values from. """ check_ndim(new_values, ndim=1, origin='Scatter3d') + old_shape = self._data.shape self._data = new_values - if self._colormapper is not None: - self._update_colors() + + if self._data.shape != old_shape: + self._new_points = self._make_point_cloud() + else: + self._new_points = None + self._new_positions = self._update_positions() + + if self._colormapper is None: + self._finalize_update() + + def _finalize_update(self) -> None: + """ + Finalize the update of the point cloud. + This is called either at the end of the position update if there is no + colormapper, and after the colors are updated in the case of a colormapper. + We want to wait for both to be ready before updating the geometry. + """ + # We use the hold context manager to avoid multiple re-draws of the scene and + # thus prevent flickering. + with self._canvas.renderer.hold(): + if self._new_points is not None: + self._canvas.remove(self.points) + self.points = self._new_points + if self._new_positions is not None: + self.position = self._new_positions + self._new_positions = None + if self._new_colors is not None: + self.color = self._new_colors + self._new_colors = None + # For some reason, adding the points to the scene before updating the colors + # still shows the old colors for a brief moment, even if hold() is active. + if self._new_points is not None: + self._new_points = None + self._canvas.add(self.points) + + @property + def position(self) -> np.ndarray: + """ + The scatter points positions as a (N, 3) numpy array. + """ + return self.geometry.attributes['position'].array + + @position.setter + def position(self, val: np.ndarray): + self.geometry.attributes['position'].array = val + + @property + def color(self) -> np.ndarray: + """ + The scatter points colors as a (N, 3) numpy array. + """ + return self.geometry.attributes['color'].array + + @color.setter + def color(self, val: np.ndarray): + self.geometry.attributes['color'].array = val + + @property + def geometry(self) -> Any: + """ + The scatter points geometry. + """ + return self.points.geometry + + @property + def material(self) -> Any: + """ + The scatter points material. + """ + return self.points.material @property def opacity(self) -> float: """ The scatter points opacity. """ - return self.material.opacity + return self._opacity @opacity.setter def opacity(self, val: float): + self._opacity = val self.material.opacity = val self.material.depthTest = val > 0.5 diff --git a/src/plopp/graphics/graphicalview.py b/src/plopp/graphics/graphicalview.py index 4b28c6b1b..8c5e5f299 100644 --- a/src/plopp/graphics/graphicalview.py +++ b/src/plopp/graphics/graphicalview.py @@ -220,6 +220,8 @@ def update(self, *args, **kwargs) -> None: if self._autoscale: self.fit_to_data() + elif self.colormapper is not None: + self.colormapper.notify_artists() self.canvas.draw() diff --git a/src/plopp/widgets/clip3d.py b/src/plopp/widgets/clip3d.py index cb414ab3a..00626e31b 100644 --- a/src/plopp/widgets/clip3d.py +++ b/src/plopp/widgets/clip3d.py @@ -28,10 +28,6 @@ def _xor(x: list[sc.Variable]) -> sc.Variable: } -def select(da: sc.DataArray, s: tuple[str, sc.Variable]) -> sc.DataArray: - return da[s] - - class Clip3dTool(ipw.HBox): """ A tool that provides a slider to extract a slab of points in a three-dimensional @@ -191,6 +187,16 @@ class ClippingPlanes(ipw.HBox): """ A widget to make clipping planes for spatial cutting (see :class:`Clip3dTool`) to make spatial cuts in the X, Y, and Z directions on a three-dimensional scatter plot. + The widget provides buttons to add/remove cuts, toggle the visibility of the cuts, + and set the operation to combine multiple cuts (OR, AND, XOR). The opacity of the + original point clouds is reduced when at least one cut is active, to provide + context. + + The selection from all cuts are combined to either create or update a + second point cloud which is included in the scene. + When the position/range of a cut is changed, only the outlines of + the cuts are moved in real time, which is cheap. The actual point cloud gets + updated less frequently using a debounce mechanism. .. versionadded:: 24.04.0 @@ -228,7 +234,7 @@ def __init__(self, fig: BaseFig): self.tabs = ipw.Tab(layout={'width': '550px'}) self._original_nodes = list(self._view.graph_nodes.values()) - self._nodes = {} + # self._nodes = {} self.add_cut_label = ipw.Label('Add cut:') layout = {'width': '45px', 'padding': '0px 0px 0px 0px'} @@ -295,6 +301,15 @@ def __init__(self, fig: BaseFig): ) self.delete_cut.on_click(self._remove_cut) + self._nodes = {} + self._cut_info_node = Node(self._get_visible_cuts) + for n in self._original_nodes: + self._nodes[n.id] = Node( + self._select_subset, da=n, cuts=self._cut_info_node + ) + self._nodes[n.id].add_view(self._view) + self.update_state() + super().__init__( [ self.tabs, @@ -354,6 +369,10 @@ def update_controls(self): self.opacity.disabled = not at_least_one_cut opacity = self.opacity.value if at_least_one_cut else 1.0 self._set_opacity({'new': opacity}) + # if not at_least_one_cut: + for n in self._original_nodes: + nid = self._nodes[n.id].id + self._view.artists[nid].visible = at_least_one_cut def _set_opacity(self, change: dict[str, Any]): """ @@ -382,42 +401,34 @@ def change_operation(self, change: dict[str, Any]): self._operation = change['new'].lower() self.update_state() - def update_state(self): + def _get_visible_cuts(self) -> list[Clip3dTool]: """ - Update the state, combining all the active cuts, using the selected binary - operation. The resulting selection is then used to either create or update a - second point cloud which is included in the scene. - The original point cloud is then set to be semi-transparent. - When the position/range of a cut is changed, this function is called via a - debounce mechanism to avoid updating the cloud too often. Only the outlines of - the cuts are moved in real time, which is cheap. + Return the list of visible cuts. """ - for nodes in self._nodes.values(): - self._view.remove(nodes['slice'].id) - nodes['slice'].remove() - self._nodes.clear() + return [cut for cut in self.cuts if cut.visible] - visible_cuts = [cut for cut in self.cuts if cut.visible] - if not visible_cuts: - return + def _select_subset(self, da: sc.DataArray, cuts: list[Clip3dTool]) -> sc.DataArray: + """ + Return the subset of the data array selected by the cuts, combined using the + selected operation. + """ + selections = [] + npoints = 0 + for cut in cuts: + xmin, xmax = cut.range + selection = (da.coords[cut.dim] >= xmin) & (da.coords[cut.dim] < xmax) + npoints += selection.sum().value + selections.append(selection) + # If no points are selected, return a dummy selection to avoid issues with + # empty selections. + if npoints == 0: + return da[0:1] + sel = OPERATIONS[self._operation](selections) + return da[sel] - for n in self._original_nodes: - da = n.request_data() - selections = [] - for cut in visible_cuts: - xmin, xmax = cut.range - selections.append( - (da.coords[cut.dim] >= xmin) & (da.coords[cut.dim] < xmax) - ) - selection = OPERATIONS[self._operation](selections) - if selection.sum().value > 0: - if n.id not in self._nodes: - select_node = Node(selection) - self._nodes[n.id] = { - 'select': select_node, - 'slice': Node(lambda da, s: da[s], da=n, s=select_node), - } - self._nodes[n.id]['slice'].add_view(self._view) - else: - self._nodes[n.id]['select'].func = lambda: selection # noqa: B023 - self._nodes[n.id]['select'].notify_children("") + def update_state(self): + """ + Update the state of the cuts in the figure by triggering the node that + provides the list of visible cuts. + """ + self._cut_info_node.notify_children("") diff --git a/tests/backends/pythreejs/pythreejs_scatter3d_test.py b/tests/backends/pythreejs/pythreejs_scatter3d_test.py index e3ad6a202..257c49752 100644 --- a/tests/backends/pythreejs/pythreejs_scatter3d_test.py +++ b/tests/backends/pythreejs/pythreejs_scatter3d_test.py @@ -8,15 +8,17 @@ from plopp.backends.pythreejs.canvas import Canvas from plopp.backends.pythreejs.scatter3d import Scatter3d from plopp.data.testing import scatter +from plopp.graphics import ColorMapper -def test_creation(): +@pytest.mark.parametrize("with_cbar", [False, True]) +def test_creation(with_cbar): da = scatter() - scat = Scatter3d(canvas=Canvas(), data=da, x='x', y='y', z='z') + canvas = Canvas() + cmapper = ColorMapper(canvas=canvas) if with_cbar else None + scat = Scatter3d(canvas=canvas, data=da, x='x', y='y', z='z', colormapper=cmapper) assert sc.identical(scat._data, da) - assert np.allclose( - scat.geometry.attributes['position'].array, da.coords['position'].values - ) + assert np.allclose(scat.position, da.coords['position'].values) def test_update(): @@ -124,3 +126,38 @@ def test_update_raises_when_data_is_not_1d(): sc.DimensionError, match='Scatter3d only accepts data with 1 dimension' ): scat.update(da2d) + + +@pytest.mark.parametrize("with_cbar", [False, True]) +def test_update_with_new_positions(with_cbar): + canvas = Canvas() + cmapper = ColorMapper(canvas=canvas) if with_cbar else None + da = scatter(npoints=500, seed=10) + scat = Scatter3d(canvas=canvas, data=da, x='x', y='y', z='z', colormapper=cmapper) + assert scat.position.shape[0] == 500 + assert scat.color.shape[0] == 500 + new = scatter(npoints=500, seed=20) + new.data = da.data # Keep the same data values + scat.update(new) + if with_cbar: + scat.notify_artist("") # To update the colors + assert scat.position.shape[0] == 500 + assert scat.color.shape[0] == 500 + assert sc.identical(scat._data, new) + + +@pytest.mark.parametrize("with_cbar", [False, True]) +def test_update_with_different_number_of_points(with_cbar): + canvas = Canvas() + cmapper = ColorMapper(canvas=canvas) if with_cbar else None + da = scatter(npoints=500) + scat = Scatter3d(canvas=canvas, data=da, x='x', y='y', z='z', colormapper=cmapper) + assert scat.position.shape[0] == 500 + assert scat.color.shape[0] == 500 + new = scatter(npoints=200) + scat.update(new) + if with_cbar: + scat.notify_artist("") # To update the colors + assert scat.position.shape[0] == 200 + assert scat.color.shape[0] == 200 + assert sc.identical(scat._data, new)