diff --git a/doc/api/visualization.rst b/doc/api/visualization.rst index 280ed51f590..7c2cf53265f 100644 --- a/doc/api/visualization.rst +++ b/doc/api/visualization.rst @@ -68,6 +68,7 @@ Visualization plot_volume_source_estimates plot_vector_source_estimates plot_sparse_source_estimates + plot_stat_cluster plot_tfr_topomap plot_topo_image_epochs plot_topomap diff --git a/doc/changes/dev/13366.newfeature.rst b/doc/changes/dev/13366.newfeature.rst new file mode 100644 index 00000000000..22796aa1fc8 --- /dev/null +++ b/doc/changes/dev/13366.newfeature.rst @@ -0,0 +1 @@ +Add :func:`~mne.viz.plot_stat_cluster` that plots the spatial extent of a cluster on top of a brain by `Shristi Baral`_. \ No newline at end of file diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index f844d9b54e5..f50efbb2de5 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -83,6 +83,7 @@ ) from ._dipole import _check_concat_dipoles, _plot_dipole_3d, _plot_dipole_mri_outlines from .evoked_field import EvokedField +from .ui_events import subscribe from .utils import ( _check_time_unit, _get_cmap, @@ -4301,3 +4302,159 @@ def _get_3d_option(key): else: opt = opt.lower() == "true" return opt + + +def plot_stat_cluster(cluster, src, brain, time="max-extent", color="magenta", width=1): + """Plot the spatial extent of a cluster on top of a brain. + + Parameters + ---------- + cluster : tuple + The cluster to plot. A cluster is a tuple of two elements: + an array of time indices + and an array of vertex indices. + src : SourceSpaces + The source space that was used for the inverse computation. + brain : Brain + The brain figure on which to plot the cluster. + time : float | "interactive" | "max-extent" + The time (in seconds) at which to plot the spatial extent of the cluster. + If set to ``"interactive"`` the time will follow the selected time in the brain + figure. + By default, ``"max-extent"``, the time of maximal spatial extent is chosen. + color : str + A maplotlib-style color specification indicating the color to use when plotting + the spatial extent of the cluster. + width : int + The width of the lines used to draw the outlines. + + Returns + ------- + brain : Brain + The brain figure, now with the cluster plotted on top of it. + """ + # Here due to circular import + from ..label import Label + + # args check + if not isinstance(cluster, tuple): + raise TypeError(f"Tuple expected, got {type(cluster)} instead.") + elif len(cluster) != 2: + raise ValueError( + "A cluster is a tuple of two elements, a list time indices " + "and list of vertex indices." + ) + else: + cluster_time_idx, cluster_vertex_index = cluster + + # A cluster is defined both in space and time. If we want to plot the boundaries of + # the cluster in space, we must choose a specific time for which to show the + # boundaries (as they change over time). + if time == "max-extent": + time_idx, n_vertices = np.unique(cluster_time_idx, return_counts=True) + time_idx = time_idx[np.argmax(n_vertices)] + elif time == "interactive": + time_idx = brain._data["time_idx"] + elif isinstance(time, float): + time_idx = np.searchsorted(brain._times[:-1], time) + else: + raise ValueError( + "Time should be 'max-extent', 'interactive', or floating point" + f" value, got '{time}' instead." + ) + + # Select only the vertex indices at the chosen time + draw_vertex_index = [ + v for v, t in zip(cluster_vertex_index, cluster_time_idx) if t == time_idx + ] + + # Create the anatomical label containing the vertex indices belonging to the + # cluster. A label cannot span both hemispheres. + # So we must filter the vertices based on their hemisphere. + + # The source space object is actually a list of two source spaces, left and right + # hemisphere. + src_lh, src_rh = src + + # Split the vertices based on the hemisphere in which they are located. + lh_verts, rh_verts = src_lh["vertno"], src_rh["vertno"] + n_lh_verts = len(lh_verts) + draw_lh_verts = [lh_verts[v] for v in draw_vertex_index if v < n_lh_verts] + draw_rh_verts = [ + rh_verts[v - n_lh_verts] for v in draw_vertex_index if v >= n_lh_verts + ] + + # Vertices in a label must be unique and in increasing order + draw_lh_verts = np.unique(draw_lh_verts) + draw_rh_verts = np.unique(draw_rh_verts) + + # We are now ready to create the anatomical label objects + cluster_index = 0 + for label in brain.labels["lh"] + brain.labels["rh"]: + if label.name.startswith("cluster-"): + try: + cluster_index = max(cluster_index, int(label.name.split("-", 1)[1])) + except ValueError: + pass + lh_label = Label(draw_lh_verts, hemi="lh", name=f"cluster-{cluster_index}") + rh_label = Label(draw_rh_verts, hemi="rh", name=f"cluster-{cluster_index}") + + # Transform vertex indices into proper vertex numbers. + # Not every vertex in the original high-resolution brain mesh is a + # source point in the source estimate. Do draw nice smooth curves, we need to + # interpolate the vertex indices. + + # Here, we interpolate the vertices in each label to the full resolution mesh + if len(lh_label) > 0: + lh_label = lh_label.smooth( + smooth=3, subject=brain._subject, subjects_dir=brain._subjects_dir + ) + brain.add_label(lh_label, borders=width, color=color) + if len(rh_label) > 0: + rh_label = rh_label.smooth( + smooth=3, subject=brain._subject, subjects_dir=brain._subjects_dir + ) + brain.add_label(rh_label, borders=width, color=color) + + def on_time_change(event): + time_idx = np.searchsorted(brain._times, event.time) + for hemi in brain._hemis: + mesh = brain._layered_meshes[hemi] + for i, label in enumerate(brain.labels[hemi]): + if label.name == f"cluster-{cluster_index}": + del brain.labels[hemi][i] + mesh.remove_overlay(label.name) + + # Select only the vertex indices at the chosen time + draw_vertex_index = [ + v for v, t in zip(cluster_vertex_index, cluster_time_idx) if t == time_idx + ] + draw_lh_verts = [lh_verts[v] for v in draw_vertex_index if v < n_lh_verts] + draw_rh_verts = [ + rh_verts[v - n_lh_verts] for v in draw_vertex_index if v >= n_lh_verts + ] + + # Vertices in a label must be unique and in increasing order + draw_lh_verts = np.unique(draw_lh_verts) + draw_rh_verts = np.unique(draw_rh_verts) + lh_label = Label(draw_lh_verts, hemi="lh", name=f"cluster-{cluster_index}") + rh_label = Label(draw_rh_verts, hemi="rh", name=f"cluster-{cluster_index}") + if len(lh_label) > 0: + lh_label = lh_label.smooth( + smooth=3, + subject=brain._subject, + subjects_dir=brain._subjects_dir, + verbose=False, + ) + brain.add_label(lh_label, borders=width, color=color) + if len(rh_label) > 0: + rh_label = rh_label.smooth( + smooth=3, + subject=brain._subject, + subjects_dir=brain._subjects_dir, + verbose=False, + ) + brain.add_label(rh_label, borders=width, color=color) + + if time == "interactive": + subscribe(brain, "time_change", on_time_change) diff --git a/mne/viz/__init__.pyi b/mne/viz/__init__.pyi index c58ad7d0e54..8a00d5a4f3d 100644 --- a/mne/viz/__init__.pyi +++ b/mne/viz/__init__.pyi @@ -72,6 +72,7 @@ __all__ = [ "plot_source_estimates", "plot_source_spectrogram", "plot_sparse_source_estimates", + "plot_stat_cluster", "plot_tfr_topomap", "plot_topo_image_epochs", "plot_topomap", @@ -97,6 +98,7 @@ from ._3d import ( plot_head_positions, plot_source_estimates, plot_sparse_source_estimates, + plot_stat_cluster, plot_vector_source_estimates, plot_volume_source_estimates, set_3d_options, diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 01d6d5a960d..9c25f22af0e 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -49,6 +49,7 @@ plot_head_positions, plot_source_estimates, plot_sparse_source_estimates, + plot_stat_cluster, snapshot_brain_montage, ) from mne.viz._3d import _get_map_ticks, _linearize_map, _process_clim @@ -1413,3 +1414,57 @@ def test_link_brains(renderer_interactive): with pytest.raises(TypeError, match="type is Brain"): link_brains("foo") link_brains(brain, time=True, camera=True) + + +@testing.requires_testing_data +def test_plot_stat_cluster(renderer_interactive): + """Test plotting clusters on brain in static and interactive mode.""" + sample_src = read_source_spaces(src_fname) + vertices = [s["vertno"] for s in sample_src] + n_time = 5 + n_verts = sum(len(v) for v in vertices) + + # simulate stc data + stc_data = np.zeros(n_verts * n_time) + stc_size = stc_data.size + stc_data[(np.random.rand(stc_size // 20) * stc_size).astype(int)] = ( + np.random.RandomState(0).rand(stc_data.size // 20) + ) + stc_data.shape = (n_verts, n_time) + stc = SourceEstimate(stc_data, vertices, 1, 1) + + # Simulate a cluster + cluster_time_idx = [1, 1, 2, 3] + cluster_vertex_idx = [0, 1, 2, 3] + cluster = (cluster_time_idx, cluster_vertex_idx) + + brain = plot_source_estimates( + stc, + "sample", + background=(1, 1, 0), + subjects_dir=subjects_dir, + colorbar=True, + clim="auto", + ) + # Test for incorrect argument in time + with pytest.raises(ValueError): + plot_stat_cluster(cluster, sample_src, brain, "foo") + + # test for incorrect shape of cluster + with pytest.raises(TypeError): + plot_stat_cluster(([1]), sample_src, brain) + + # test for incorrect data type of cluster + with pytest.raises(TypeError): + plot_stat_cluster([[1, 2, 3], [1, 2, 3]], sample_src, brain) + + # All arguments are correct + plot_stat_cluster(cluster, sample_src, brain) + + # Check that the proper anatomical label has been constructed. + assert len(brain.labels["lh"]) == 1 + assert len(brain.labels["rh"]) == 0 + assert brain.labels["lh"][0].name == "cluster-0" + + brain.close() + del brain diff --git a/tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py b/tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py index 53e90f78d01..f5315b65689 100644 --- a/tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py +++ b/tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py @@ -29,6 +29,7 @@ from mne.epochs import equalize_epoch_counts from mne.minimum_norm import apply_inverse, read_inverse_operator from mne.stats import spatio_temporal_cluster_1samp_test, summarize_clusters_stc +from mne.viz import plot_stat_cluster # %% # Set parameters @@ -142,19 +143,18 @@ # Read the source space we are morphing to src = mne.read_source_spaces(src_fname) fsave_vertices = [s["vertno"] for s in src] -morph_mat = mne.compute_source_morph( +morph = mne.compute_source_morph( src=inverse_operator["src"], subject_to="fsaverage", spacing=fsave_vertices, subjects_dir=subjects_dir, -).morph_mat - -n_vertices_fsave = morph_mat.shape[0] +) +n_vertices_fsave = morph.morph_mat.shape[0] # We have to change the shape for the dot() to work properly X = X.reshape(n_vertices_sample, n_times * n_subjects * 2) print("Morphing data.") -X = morph_mat.dot(X) # morph_mat is a sparse matrix +X = morph.morph_mat.dot(X) # morph_mat is a sparse matrix X = X.reshape(n_vertices_fsave, n_times, n_subjects, 2) # %% @@ -264,3 +264,27 @@ # We could save this via the following: # brain.save_image('clusters.png') + +# %% +# Alternatively, you may wish to observe the spatial and temporal extent of +# single clusters. The code below demonstrates how to plot the cluster +# boundary on top of an existing source estimate. + +difference = morph.apply(condition1 - condition2) +difference_plot = difference.plot( + hemi="both", + views="lateral", + subjects_dir=subjects_dir, + size=(800, 800), + initial_time=0.1, +) + +# Plot one cluster at the time of maximal spatial extent of that cluster +plot_stat_cluster( + good_clusters[2], src, difference_plot, time="max-extent", color="magenta", width=1 +) + +# Plotting the cluster in interactive mode allows scrolling through time +plot_stat_cluster( + good_clusters[2], src, difference_plot, time="interactive", color="magenta", width=1 +)