Skip to content

Commit 0f1e77a

Browse files
authored
Merge pull request #364 from rsagroup/meg-demo-fixes
MEG demo update
2 parents 218b74d + 12d261e commit 0f1e77a

File tree

4 files changed

+441
-321
lines changed

4 files changed

+441
-321
lines changed

demos/demo_meg_mne.ipynb

Lines changed: 261 additions & 321 deletions
Large diffs are not rendered by default.

docs/source/rsatoolbox.vis.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Submodules
1212
rsatoolbox.vis.model_plot
1313
rsatoolbox.vis.rdm_plot
1414
rsatoolbox.vis.scatter_plot
15+
rsatoolbox.vis.timecourse
1516

1617
Module contents
1718
---------------
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
rsatoolbox.vis.timecourse module
2+
================================
3+
4+
.. automodule:: rsatoolbox.vis.timecourse
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:

src/rsatoolbox/vis/timecourse.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""Lineplot of dissimilarity over time
2+
3+
See demo_meg_mne for an example.
4+
"""
5+
# pylint: disable=too-many-statements,unused-argument,too-many-locals
6+
from __future__ import annotations
7+
from typing import TYPE_CHECKING, Tuple, List, Optional, Dict
8+
import matplotlib.pyplot as plt
9+
import numpy as np
10+
if TYPE_CHECKING:
11+
from rsatoolbox.rdm.rdms import RDMs
12+
from matplotlib.axes._axes import Axes
13+
from matplotlib.figure import Figure
14+
from numpy.typing import NDArray
15+
16+
17+
def plot_timecourse(
18+
rdms_data: RDMs,
19+
descriptor: str,
20+
n_t_display:int = 20,
21+
fig_width: Optional[int] = None,
22+
timecourse_plot_rel_height: Optional[int] = None,
23+
time_formatted: Optional[List[str]] = None,
24+
colored_conditions: Optional[list] = None,
25+
plot_individual_dissimilarities: Optional[bool] = None,
26+
) -> Tuple[Figure, List[Axes]]:
27+
""" plots the RDM movie for a given descriptor
28+
29+
Args:
30+
rdms_data (rsatoolbox.rdm.RDMs): rdm movie
31+
descriptor (str): name of the descriptor that created the rdm movie
32+
n_t_display (int, optional): number of RDM time points to display. Defaults to 20.
33+
fig_width (int, optional): width of the figure (in inches). Defaults to None.
34+
timecourse_plot_rel_height (int, optional): height of the timecourse plot (relative to
35+
the rdm movie row).
36+
time_formatted (List[str], optional): time points formatted as strings.
37+
Defaults to None (i.e., rdms_data.time_descriptors['time'] is considered to
38+
be in seconds)
39+
colored_condiitons (list, optional): vector of pattern condition names to dissimilarities
40+
according to a categorical model on colored_conditions Defaults to None.
41+
plot_individual_dissimilarities (bool, optional): whether to plot the individual
42+
dissimilarities. Defaults to None (i.e., False if colored_conditions is not
43+
None, True otherwise).
44+
45+
Returns:
46+
Tuple[matplotlib.figure.Figure, npt.ArrayLike, collections.defaultdict]:
47+
48+
Tuple of
49+
- Handle to created figure
50+
- Subplot axis handles from plt.subplots.
51+
"""
52+
# create labels
53+
time = rdms_data.rdm_descriptors['time']
54+
unique_time = np.unique(time)
55+
time_formatted = time_formatted or [f'{np.round(x*1000,2):0.0f} ms' for x in unique_time]
56+
57+
n_dissimilarity_elements = rdms_data.dissimilarities.shape[1]
58+
59+
# color mapping from colored conditions
60+
plot_individual_dissimilarities, color_index = _map_colors(
61+
colored_conditions, plot_individual_dissimilarities, rdms_data)
62+
63+
colors = plt.get_cmap('turbo')(np.linspace(0, 1, len(color_index)+1))
64+
65+
# how many rdms to display
66+
n_times = len(unique_time)
67+
t_display_idx = (np.round(np.linspace(0, n_times-1, min(n_times, n_t_display)))).astype(int)
68+
t_display_idx = np.unique(t_display_idx)
69+
n_t_display = len(t_display_idx)
70+
71+
# auto determine relative sizes of axis
72+
timecourse_plot_rel_height = timecourse_plot_rel_height or n_t_display // 3
73+
base_size = 40 / n_t_display if fig_width is None else fig_width / n_t_display
74+
75+
# figure layout
76+
fig = plt.figure(
77+
constrained_layout=True,
78+
figsize=(base_size*n_t_display,base_size*timecourse_plot_rel_height)
79+
)
80+
gs = fig.add_gridspec(timecourse_plot_rel_height+1, n_t_display)
81+
tc_ax = fig.add_subplot(gs[:-1,:])
82+
rdm_axes = [fig.add_subplot(gs[-1,i]) for i in range(n_t_display)]
83+
84+
# plot dissimilarity timecourses
85+
dissimilarities_mean = np.zeros((rdms_data.dissimilarities.shape[1], len(unique_time)))
86+
for i, t in enumerate(unique_time):
87+
dissimilarities_mean[:, i] = np.mean(rdms_data.dissimilarities[t == time, :], axis=0)
88+
89+
def _plot_mean_dissimilarities(labels=False):
90+
for i, (pairwise_name, idx) in enumerate(color_index.items()):
91+
mn = np.mean(dissimilarities_mean[idx, :],axis=0)
92+
n = np.sqrt(dissimilarities_mean.shape[0])
93+
# se is over dissimilarities, not over subjects
94+
se = np.std(dissimilarities_mean[idx, :],axis=0)/n
95+
tc_ax.fill_between(unique_time, mn-se, mn+se, color=colors[i], alpha=.3)
96+
label = pairwise_name if labels else None
97+
tc_ax.plot(unique_time, mn, color=colors[i], linewidth=2, label=label)
98+
99+
def _plot_individual_dissimilarities():
100+
for i, (_, idx) in enumerate(color_index.items()):
101+
a = max(1/255., 1/n_dissimilarity_elements)
102+
tc_ax.plot(unique_time, dissimilarities_mean[idx, :].T, color=colors[i], alpha=a)
103+
104+
if plot_individual_dissimilarities:
105+
if colored_conditions is not None:
106+
_plot_mean_dissimilarities()
107+
yl = tc_ax.get_ylim()
108+
_plot_individual_dissimilarities()
109+
tc_ax.set_ylim(yl)
110+
else:
111+
_plot_individual_dissimilarities()
112+
113+
if colored_conditions is not None:
114+
_plot_mean_dissimilarities(True)
115+
116+
yl = tc_ax.get_ylim()
117+
for t in unique_time[t_display_idx]:
118+
tc_ax.plot([t,t], yl, linestyle=':', color='b', alpha=0.3)
119+
tc_ax.set_ylabel(f'Dissimilarity\n({rdms_data.dissimilarity_measure})')
120+
tc_ax.set_xticks(unique_time)
121+
tc_ax.set_xticklabels([
122+
time_formatted[idx] if idx in t_display_idx else '' for idx in range(len(unique_time))
123+
])
124+
dt = np.diff(unique_time[t_display_idx])[0]
125+
tc_ax.set_xlim(unique_time[t_display_idx[0]]-dt/2, unique_time[t_display_idx[-1]]+dt/2)
126+
127+
tc_ax.legend()
128+
129+
# display (selected) rdms
130+
vmax = np.std(rdms_data.dissimilarities) * 2
131+
for i, (tidx, a) in enumerate(zip(t_display_idx, rdm_axes)):
132+
mean_dissim = np.mean(rdms_data.subset('time', unique_time[tidx]).get_matrices(),axis=0)
133+
a.imshow(mean_dissim, vmin=0, vmax=vmax)
134+
a.set_title(f'{np.round(unique_time[tidx]*1000,2):0.0f} ms')
135+
a.set_yticklabels([])
136+
a.set_yticks([])
137+
a.set_xticklabels([])
138+
a.set_xticks([])
139+
140+
return fig, [tc_ax] + rdm_axes
141+
142+
143+
def unsquareform(a: NDArray) -> NDArray:
144+
"""Helper function; convert squareform to vector
145+
"""
146+
return a[np.nonzero(np.triu(a, k=1))]
147+
148+
149+
def _map_colors(
150+
colored_conditions: Optional[list],
151+
plot_individual_dissimilarities: Optional[bool],
152+
rdms: RDMs
153+
) -> Tuple[bool, Dict[str, NDArray]]:
154+
n_dissimilarity_elements = rdms.dissimilarities.shape[1]
155+
# color mapping from colored conditions
156+
if colored_conditions is not None:
157+
if plot_individual_dissimilarities is None:
158+
plot_individual_dissimilarities = False
159+
sf_conds = [[{c1, c2} for c1 in colored_conditions] for c2 in colored_conditions]
160+
pairwise_conds = unsquareform(np.array(sf_conds))
161+
pairwise_conds_unique = np.unique(pairwise_conds)
162+
color_index = {}
163+
for x in pairwise_conds_unique:
164+
if len(list(x))==2:
165+
key = f'{list(x)[0]} vs {list(x)[1]}'
166+
else:
167+
key = f'{list(x)[0]} vs {list(x)[0]}'
168+
color_index[key] = pairwise_conds==x
169+
else:
170+
color_index = {'': np.array([True]*n_dissimilarity_elements)}
171+
plot_individual_dissimilarities = True
172+
return plot_individual_dissimilarities, color_index

0 commit comments

Comments
 (0)