|
| 1 | +"""Lineplot of dissimilarity over time |
| 2 | +
|
| 3 | +See demo_meg_mne for an example. |
| 4 | +""" |
| 5 | +# pylint: disable=too-many-statements,unused-argument,too-many-locals |
| 6 | +from __future__ import annotations |
| 7 | +from typing import TYPE_CHECKING, Tuple, List, Optional, Dict |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +import numpy as np |
| 10 | +if TYPE_CHECKING: |
| 11 | + from rsatoolbox.rdm.rdms import RDMs |
| 12 | + from matplotlib.axes._axes import Axes |
| 13 | + from matplotlib.figure import Figure |
| 14 | + from numpy.typing import NDArray |
| 15 | + |
| 16 | + |
| 17 | +def plot_timecourse( |
| 18 | + rdms_data: RDMs, |
| 19 | + descriptor: str, |
| 20 | + n_t_display:int = 20, |
| 21 | + fig_width: Optional[int] = None, |
| 22 | + timecourse_plot_rel_height: Optional[int] = None, |
| 23 | + time_formatted: Optional[List[str]] = None, |
| 24 | + colored_conditions: Optional[list] = None, |
| 25 | + plot_individual_dissimilarities: Optional[bool] = None, |
| 26 | + ) -> Tuple[Figure, List[Axes]]: |
| 27 | + """ plots the RDM movie for a given descriptor |
| 28 | +
|
| 29 | + Args: |
| 30 | + rdms_data (rsatoolbox.rdm.RDMs): rdm movie |
| 31 | + descriptor (str): name of the descriptor that created the rdm movie |
| 32 | + n_t_display (int, optional): number of RDM time points to display. Defaults to 20. |
| 33 | + fig_width (int, optional): width of the figure (in inches). Defaults to None. |
| 34 | + timecourse_plot_rel_height (int, optional): height of the timecourse plot (relative to |
| 35 | + the rdm movie row). |
| 36 | + time_formatted (List[str], optional): time points formatted as strings. |
| 37 | + Defaults to None (i.e., rdms_data.time_descriptors['time'] is considered to |
| 38 | + be in seconds) |
| 39 | + colored_condiitons (list, optional): vector of pattern condition names to dissimilarities |
| 40 | + according to a categorical model on colored_conditions Defaults to None. |
| 41 | + plot_individual_dissimilarities (bool, optional): whether to plot the individual |
| 42 | + dissimilarities. Defaults to None (i.e., False if colored_conditions is not |
| 43 | + None, True otherwise). |
| 44 | +
|
| 45 | + Returns: |
| 46 | + Tuple[matplotlib.figure.Figure, npt.ArrayLike, collections.defaultdict]: |
| 47 | +
|
| 48 | + Tuple of |
| 49 | + - Handle to created figure |
| 50 | + - Subplot axis handles from plt.subplots. |
| 51 | + """ |
| 52 | + # create labels |
| 53 | + time = rdms_data.rdm_descriptors['time'] |
| 54 | + unique_time = np.unique(time) |
| 55 | + time_formatted = time_formatted or [f'{np.round(x*1000,2):0.0f} ms' for x in unique_time] |
| 56 | + |
| 57 | + n_dissimilarity_elements = rdms_data.dissimilarities.shape[1] |
| 58 | + |
| 59 | + # color mapping from colored conditions |
| 60 | + plot_individual_dissimilarities, color_index = _map_colors( |
| 61 | + colored_conditions, plot_individual_dissimilarities, rdms_data) |
| 62 | + |
| 63 | + colors = plt.get_cmap('turbo')(np.linspace(0, 1, len(color_index)+1)) |
| 64 | + |
| 65 | + # how many rdms to display |
| 66 | + n_times = len(unique_time) |
| 67 | + t_display_idx = (np.round(np.linspace(0, n_times-1, min(n_times, n_t_display)))).astype(int) |
| 68 | + t_display_idx = np.unique(t_display_idx) |
| 69 | + n_t_display = len(t_display_idx) |
| 70 | + |
| 71 | + # auto determine relative sizes of axis |
| 72 | + timecourse_plot_rel_height = timecourse_plot_rel_height or n_t_display // 3 |
| 73 | + base_size = 40 / n_t_display if fig_width is None else fig_width / n_t_display |
| 74 | + |
| 75 | + # figure layout |
| 76 | + fig = plt.figure( |
| 77 | + constrained_layout=True, |
| 78 | + figsize=(base_size*n_t_display,base_size*timecourse_plot_rel_height) |
| 79 | + ) |
| 80 | + gs = fig.add_gridspec(timecourse_plot_rel_height+1, n_t_display) |
| 81 | + tc_ax = fig.add_subplot(gs[:-1,:]) |
| 82 | + rdm_axes = [fig.add_subplot(gs[-1,i]) for i in range(n_t_display)] |
| 83 | + |
| 84 | + # plot dissimilarity timecourses |
| 85 | + dissimilarities_mean = np.zeros((rdms_data.dissimilarities.shape[1], len(unique_time))) |
| 86 | + for i, t in enumerate(unique_time): |
| 87 | + dissimilarities_mean[:, i] = np.mean(rdms_data.dissimilarities[t == time, :], axis=0) |
| 88 | + |
| 89 | + def _plot_mean_dissimilarities(labels=False): |
| 90 | + for i, (pairwise_name, idx) in enumerate(color_index.items()): |
| 91 | + mn = np.mean(dissimilarities_mean[idx, :],axis=0) |
| 92 | + n = np.sqrt(dissimilarities_mean.shape[0]) |
| 93 | + # se is over dissimilarities, not over subjects |
| 94 | + se = np.std(dissimilarities_mean[idx, :],axis=0)/n |
| 95 | + tc_ax.fill_between(unique_time, mn-se, mn+se, color=colors[i], alpha=.3) |
| 96 | + label = pairwise_name if labels else None |
| 97 | + tc_ax.plot(unique_time, mn, color=colors[i], linewidth=2, label=label) |
| 98 | + |
| 99 | + def _plot_individual_dissimilarities(): |
| 100 | + for i, (_, idx) in enumerate(color_index.items()): |
| 101 | + a = max(1/255., 1/n_dissimilarity_elements) |
| 102 | + tc_ax.plot(unique_time, dissimilarities_mean[idx, :].T, color=colors[i], alpha=a) |
| 103 | + |
| 104 | + if plot_individual_dissimilarities: |
| 105 | + if colored_conditions is not None: |
| 106 | + _plot_mean_dissimilarities() |
| 107 | + yl = tc_ax.get_ylim() |
| 108 | + _plot_individual_dissimilarities() |
| 109 | + tc_ax.set_ylim(yl) |
| 110 | + else: |
| 111 | + _plot_individual_dissimilarities() |
| 112 | + |
| 113 | + if colored_conditions is not None: |
| 114 | + _plot_mean_dissimilarities(True) |
| 115 | + |
| 116 | + yl = tc_ax.get_ylim() |
| 117 | + for t in unique_time[t_display_idx]: |
| 118 | + tc_ax.plot([t,t], yl, linestyle=':', color='b', alpha=0.3) |
| 119 | + tc_ax.set_ylabel(f'Dissimilarity\n({rdms_data.dissimilarity_measure})') |
| 120 | + tc_ax.set_xticks(unique_time) |
| 121 | + tc_ax.set_xticklabels([ |
| 122 | + time_formatted[idx] if idx in t_display_idx else '' for idx in range(len(unique_time)) |
| 123 | + ]) |
| 124 | + dt = np.diff(unique_time[t_display_idx])[0] |
| 125 | + tc_ax.set_xlim(unique_time[t_display_idx[0]]-dt/2, unique_time[t_display_idx[-1]]+dt/2) |
| 126 | + |
| 127 | + tc_ax.legend() |
| 128 | + |
| 129 | + # display (selected) rdms |
| 130 | + vmax = np.std(rdms_data.dissimilarities) * 2 |
| 131 | + for i, (tidx, a) in enumerate(zip(t_display_idx, rdm_axes)): |
| 132 | + mean_dissim = np.mean(rdms_data.subset('time', unique_time[tidx]).get_matrices(),axis=0) |
| 133 | + a.imshow(mean_dissim, vmin=0, vmax=vmax) |
| 134 | + a.set_title(f'{np.round(unique_time[tidx]*1000,2):0.0f} ms') |
| 135 | + a.set_yticklabels([]) |
| 136 | + a.set_yticks([]) |
| 137 | + a.set_xticklabels([]) |
| 138 | + a.set_xticks([]) |
| 139 | + |
| 140 | + return fig, [tc_ax] + rdm_axes |
| 141 | + |
| 142 | + |
| 143 | +def unsquareform(a: NDArray) -> NDArray: |
| 144 | + """Helper function; convert squareform to vector |
| 145 | + """ |
| 146 | + return a[np.nonzero(np.triu(a, k=1))] |
| 147 | + |
| 148 | + |
| 149 | +def _map_colors( |
| 150 | + colored_conditions: Optional[list], |
| 151 | + plot_individual_dissimilarities: Optional[bool], |
| 152 | + rdms: RDMs |
| 153 | + ) -> Tuple[bool, Dict[str, NDArray]]: |
| 154 | + n_dissimilarity_elements = rdms.dissimilarities.shape[1] |
| 155 | + # color mapping from colored conditions |
| 156 | + if colored_conditions is not None: |
| 157 | + if plot_individual_dissimilarities is None: |
| 158 | + plot_individual_dissimilarities = False |
| 159 | + sf_conds = [[{c1, c2} for c1 in colored_conditions] for c2 in colored_conditions] |
| 160 | + pairwise_conds = unsquareform(np.array(sf_conds)) |
| 161 | + pairwise_conds_unique = np.unique(pairwise_conds) |
| 162 | + color_index = {} |
| 163 | + for x in pairwise_conds_unique: |
| 164 | + if len(list(x))==2: |
| 165 | + key = f'{list(x)[0]} vs {list(x)[1]}' |
| 166 | + else: |
| 167 | + key = f'{list(x)[0]} vs {list(x)[0]}' |
| 168 | + color_index[key] = pairwise_conds==x |
| 169 | + else: |
| 170 | + color_index = {'': np.array([True]*n_dissimilarity_elements)} |
| 171 | + plot_individual_dissimilarities = True |
| 172 | + return plot_individual_dissimilarities, color_index |
0 commit comments