diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 147ee6dbea6..7336b4e23e4 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -23,6 +23,8 @@ Changelog - Add label center of mass function :func:`mne.Label.center_of_mass` by `Eric Larson`_ + - Added :func:`mne.viz.ica.plot_ica_properties` that allows ploting of independent component properties similar to ``pop_prop`` in EEGLAB. Also :class:`mne.preprocessing.ica.ICA` has :func:`mne.preprocessing.ica.ICA.plot_properties` method now. Added by `Mikołaj Magnuski`_ + BUG ~~~ @@ -59,6 +61,8 @@ API - :func:`mne.concatenate_epochs` and :func:`mne.compute_covariance` now check to see if all :class:`Epochs` instances have the same MEG-to-Head transformation, and errors by default if they do not by `Eric Larson`_ + - Added option to pass a list of axes to :func:`mne.viz.epochs.plot_epochs_image` by `Mikołaj Magnuski`_ + .. _changes_0_12: Version 0.12 @@ -1603,3 +1607,5 @@ of commits): .. _Pablo-Arias: https://github.com/Pablo-Arias .. _Alexander Rudiuk: https://github.com/ARudiuk + +.. _Mikołaj Magnuski: https://github.com/mmagnuski diff --git a/examples/preprocessing/plot_run_ica.py b/examples/preprocessing/plot_run_ica.py index b9114ca56fd..6c724d3d63a 100644 --- a/examples/preprocessing/plot_run_ica.py +++ b/examples/preprocessing/plot_run_ica.py @@ -44,3 +44,4 @@ ecg_inds, scores = ica.find_bads_ecg(ecg_epochs) ica.plot_components(ecg_inds) +ica.plot_properties(epochs, picks=ecg_inds) diff --git a/mne/epochs.py b/mne/epochs.py index 87083470c05..9e462737939 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -37,7 +37,7 @@ from .channels.channels import (ContainsMixin, UpdateChannelsMixin, SetChannelsMixin, InterpolationMixin) from .filter import resample, detrend, FilterMixin -from .event import _read_events_fif +from .event import _read_events_fif, make_fixed_length_events from .fixes import in1d, _get_args from .viz import (plot_epochs, plot_epochs_psd, plot_epochs_psd_topomap, plot_epochs_image, plot_topo_image_epochs) @@ -3112,3 +3112,29 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, _remove_meg_projs(evoked) # remove MEG projectors, they won't apply now logger.info('Created Evoked dataset from %s epochs' % (count,)) return (evoked, mapping) if return_mapping else evoked + + +@verbose +def _segment_raw(raw, segment_length=1., verbose=None, **kwargs): + """Divide continuous raw data into equal-sized + consecutive epochs. + + Parameters + ---------- + raw : instance of Raw + Raw data to divide into segments. + segment_length : float + Length of each segment in seconds. Defaults to 1. + verbose: bool + Whether to report what is being done by printing text. + **kwargs + Any additional keyword arguments are passed to ``Epochs`` constructor. + + Returns + ------- + epochs : instance of ``Epochs`` + Segmented data. + """ + events = make_fixed_length_events(raw, 1, duration=segment_length) + return Epochs(raw, events, event_id=[1], tmin=0., tmax=segment_length, + verbose=verbose, baseline=None, **kwargs) diff --git a/mne/event.py b/mne/event.py index 6683e1507df..446978f89da 100644 --- a/mne/event.py +++ b/mne/event.py @@ -765,6 +765,15 @@ def make_fixed_length_events(raw, id, start=0, stop=None, duration=1., new_events : array The new events. """ + from .io.base import _BaseRaw + if not isinstance(raw, _BaseRaw): + raise ValueError('Input data must be an instance of Raw, got' + ' %s instead.' % (type(raw))) + if not isinstance(id, int): + raise ValueError('id must be an integer') + if not isinstance(duration, (int, float)): + raise ValueError('duration must be an integer of a float, ' + 'got %s instead.' % (type(duration))) start = raw.time_as_index(start)[0] if stop is not None: stop = raw.time_as_index(stop)[0] @@ -775,8 +784,6 @@ def make_fixed_length_events(raw, id, start=0, stop=None, duration=1., stop = min([stop + raw.first_samp, raw.last_samp + 1]) else: stop = min([stop, len(raw.times)]) - if not isinstance(id, int): - raise ValueError('id must be an integer') # Make sure we don't go out the end of the file: stop -= int(np.ceil(raw.info['sfreq'] * duration)) # This should be inclusive due to how we generally use start and stop... diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 486565f3061..f745e85151c 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -35,6 +35,7 @@ from ..epochs import _BaseEpochs from ..viz import (plot_ica_components, plot_ica_scores, plot_ica_sources, plot_ica_overlay) +from ..viz.ica import plot_ica_properties from ..viz.utils import (_prepare_trellis, tight_layout, plt_show, _setup_vmin_vmax) from ..viz.topomap import (_prepare_topo_plot, _check_outlines, @@ -1400,6 +1401,58 @@ def plot_components(self, picks=None, ch_type=None, res=64, layout=None, image_interp=image_interp, head_pos=head_pos) + def plot_properties(self, inst, picks=None, axes=None, dB=True, + plot_std=True, topomap_args=None, image_args=None, + psd_args=None, figsize=None, show=True): + """Display component properties: topography, epochs image, ERP, + power spectrum and epoch variance. + + Parameters + ---------- + inst: instance of Epochs or Raw + The data to use in plotting properties. + picks : int | array-like of int | None + The components to be displayed. If None, plot will show the first + five sources. If more than one components were chosen in the picks, + each one will be plotted in a separate figure. Defaults to None. + axes: list of matplotlib axes | None + List of five matplotlib axes to use in plotting: [topo_axis, + image_axis, erp_axis, spectrum_axis, variance_axis]. If None a new + figure with relevant axes is created. Defaults to None. + dB: bool + Whether to plot spectrum in dB. Defaults to True. + plot_std: bool | float + Whether to plot standard deviation in ERP/ERF and spectrum plots. + Defaults to True, which plots one standard deviation above/below. + If set to float allows to control how many standard deviations are + plotted. For example 2.5 will plot 2.5 standard deviation + above/below. + topomap_args : dict | None + Dictionary of arguments to ``plot_topomap``. If None, doesn't pass + any additional arguments. Defaults to None. + image_args : dict | None + Dictionary of arguments to ``plot_epochs_image``. If None, doesn't + pass any additional arguments. Defaults to None. + psd_args : dict | None + Dictionary of arguments to ``psd_multitaper``. If None, doesn't + pass any additional arguments. Defaults to None. + figsize : array-like of size (2,) | None + Allows to control size of the figure. If None the figure size + defauls to [7., 6.]. + show : bool + Show figure if True. + + Returns + ------- + fig : list + List of matplotlib figures. + """ + return plot_ica_properties(inst, self, picks=picks, axes=axes, + dB=dB, plot_std=plot_std, + topomap_args=topomap_args, + image_args=image_args, psd_args=psd_args, + figsize=figsize, show=show) + def plot_sources(self, inst, picks=None, exclude=None, start=None, stop=None, title=None, show=True, block=False): """Plot estimated latent sources given the unmixing matrix. diff --git a/mne/tests/test_event.py b/mne/tests/test_event.py index 5df909213b1..6587204c973 100644 --- a/mne/tests/test_event.py +++ b/mne/tests/test_event.py @@ -353,6 +353,11 @@ def test_make_fixed_length_events(): # With bad limits (no resulting events) assert_raises(ValueError, make_fixed_length_events, raw, 1, tmin, tmax - 1e-3, duration) + # not raw, bad id or duration + assert_raises(ValueError, make_fixed_length_events, raw, 2.3) + assert_raises(ValueError, make_fixed_length_events, 'not raw', 2) + assert_raises(ValueError, make_fixed_length_events, raw, 23, tmin, tmax, + 'abc') def test_define_events(): diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index 25854583f8e..2ff60d98699 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -29,7 +29,7 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None, vmax=None, colorbar=True, order=None, show=True, units=None, scalings=None, cmap='RdBu_r', - fig=None, overlay_times=None): + fig=None, axes=None, overlay_times=None): """Plot Event Related Potential / Fields image Parameters @@ -71,6 +71,11 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None, Figure instance to draw the image to. Figure must contain two axes for drawing the single trials and evoked responses. If None a new figure is created. Defaults to None. + axes : list of matplotlib axes | None + List of axes instances to draw the image, erp and colorbar to. + Must be of length three if colorbar is True (with the last list element + being the colorbar axes) or two if colorbar is False. If both fig and + axes are passed an error is raised. Defaults to None. overlay_times : array-like, shape (n_epochs,) | None If not None the parameter is interpreted as time instants in seconds and is added to the image. It is typically useful to display reaction @@ -95,8 +100,20 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None, raise ValueError('Scalings and units must have the same keys.') picks = np.atleast_1d(picks) - if fig is not None and len(picks) > 1: + if (fig is not None or axes is not None) and len(picks) > 1: raise ValueError('Only single pick can be drawn to a figure.') + if axes is not None: + if fig is not None: + raise ValueError('Both figure and axes were passed, please' + 'decide between the two.') + from .utils import _validate_if_list_of_axes + oblig_len = 3 if colorbar else 2 + _validate_if_list_of_axes(axes, obligatory_len=oblig_len) + ax1, ax2 = axes[:2] + # if axes were passed - we ignore fig param and get figure from axes + fig = ax1.get_figure() + if colorbar: + ax3 = axes[-1] evoked = epochs.average(picks) data = epochs.get_data()[:, picks, :] scale_vmin = True if vmin is None else False @@ -154,7 +171,11 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None, this_data = ndimage.gaussian_filter1d(this_data, sigma=sigma, axis=0) plt.figure(this_fig.number) - ax1 = plt.subplot2grid((3, 10), (0, 0), colspan=9, rowspan=2) + if axes is None: + ax1 = plt.subplot2grid((3, 10), (0, 0), colspan=9, rowspan=2) + ax2 = plt.subplot2grid((3, 10), (2, 0), colspan=9, rowspan=1) + if colorbar: + ax3 = plt.subplot2grid((3, 10), (0, 9), colspan=1, rowspan=3) if scale_vmin: vmin *= scalings[ch_type] if scale_vmax: @@ -167,9 +188,6 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None, if this_overlay_times is not None: plt.plot(1e3 * this_overlay_times, 0.5 + np.arange(len(this_data)), 'k', linewidth=2) - ax2 = plt.subplot2grid((3, 10), (2, 0), colspan=9, rowspan=1) - if colorbar: - ax3 = plt.subplot2grid((3, 10), (0, 9), colspan=1, rowspan=3) ax1.set_title(epochs.ch_names[idx]) ax1.set_ylabel('Epochs') ax1.axis('auto') diff --git a/mne/viz/ica.py b/mne/viz/ica.py index 70e93cbc2bb..07e89fdb253 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -15,15 +15,17 @@ from .utils import (tight_layout, _prepare_trellis, _select_bads, _layout_figure, _plot_raw_onscroll, _mouse_click, _helper_raw_resize, _plot_raw_onkey, plt_show) +from .topomap import (_prepare_topo_plot, plot_topomap, _hide_frame, + _plot_ica_topomap) from .raw import _prepare_mne_browse_raw, _plot_raw_traces -from .epochs import _prepare_mne_browse_epochs +from .epochs import _prepare_mne_browse_epochs, plot_epochs_image from .evoked import _butterfly_on_button_press, _butterfly_onpick -from .topomap import _prepare_topo_plot, plot_topomap, _hide_frame from ..utils import warn from ..defaults import _handle_default from ..io.meas_info import create_info from ..io.pick import pick_types from ..externals.six import string_types +from ..time_frequency.psd import psd_multitaper def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, @@ -106,6 +108,251 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None, return fig +def _create_properties_layout(figsize=None): + """creates main figure and axes layout used by plot_ica_properties""" + import matplotlib.pyplot as plt + if figsize is None: + figsize = [7., 6.] + fig = plt.figure(figsize=figsize, facecolor=[0.95] * 3) + ax = list() + ax.append(fig.add_axes([0.08, 0.5, 0.3, 0.45], label='topomap')) + ax.append(fig.add_axes([0.5, 0.6, 0.45, 0.35], label='image')) + ax.append(fig.add_axes([0.5, 0.5, 0.45, 0.1], label='erp')) + ax.append(fig.add_axes([0.08, 0.1, 0.32, 0.3], label='spectrum')) + ax.append(fig.add_axes([0.5, 0.1, 0.45, 0.25], label='variance')) + return fig, ax + + +def plot_ica_properties(inst, ica, picks=None, axes=None, dB=True, + plot_std=True, topomap_args=None, image_args=None, + psd_args=None, figsize=None, show=True): + """Display component properties: topography, epochs image, ERP/ERF, + power spectrum and epoch variance. + + Parameters + ---------- + inst: instance of Epochs or Raw + The data to use in plotting properties. + ica : instance of mne.preprocessing.ICA + The ICA solution. + picks : int | array-like of int | None + The components to be displayed. If None, plot will show the first + five sources. If more than one components were chosen in the picks, + each one will be plotted in a separate figure. Defaults to None. + axes: list of matplotlib axes | None + List of five matplotlib axes to use in plotting: [topomap_axis, + image_axis, erp_axis, spectrum_axis, variance_axis]. If None a new + figure with relevant axes is created. Defaults to None. + dB: bool + Whether to plot spectrum in dB. Defaults to True. + plot_std: bool | float + Whether to plot standard deviation in ERP/ERF and spectrum plots. + Defaults to True, which plots one standard deviation above/below. + If set to float allows to control how many standard deviations are + plotted. For example 2.5 will plot 2.5 standard deviation above/below. + topomap_args : dict | None + Dictionary of arguments to ``plot_topomap``. If None, doesn't pass any + additional arguments. Defaults to None. + image_args : dict | None + Dictionary of arguments to ``plot_epochs_image``. If None, doesn't pass + any additional arguments. Defaults to None. + psd_args : dict | None + Dictionary of arguments to ``psd_multitaper``. If None, doesn't pass + any additional arguments. Defaults to None. + figsize : array-like of size (2,) | None + Allows to control size of the figure. If None, the figure size + defauls to [7., 6.]. + show : bool + Show figure if True. + + Returns + ------- + fig : list + List of matplotlib figures. + """ + from ..io.base import _BaseRaw + from ..epochs import _BaseEpochs + from ..preprocessing import ICA + + if not isinstance(inst, (_BaseRaw, _BaseEpochs)): + raise ValueError('inst should be an instance of Raw or Epochs,' + ' got %s instead.' % type(inst)) + if not isinstance(ica, ICA): + raise ValueError('ica has to be an instance of ICA, ' + 'got %s instead' % type(ica)) + if isinstance(plot_std, bool): + num_std = 1. if plot_std else 0. + elif isinstance(plot_std, (float, int)): + num_std = plot_std + plot_std = True + else: + raise ValueError('plot_std has to be a bool, int or float, ' + 'got %s instead' % type(plot_std)) + + # if no picks given - plot the first 5 components + picks = list(range(min(5, ica.n_components_))) if picks is None else picks + picks = [picks] if isinstance(picks, int) else picks + if axes is None: + fig, axes = _create_properties_layout(figsize=figsize) + else: + if len(picks) > 1: + raise ValueError('Only a single pick can be drawn ' + 'to a set of axes.') + from .utils import _validate_if_list_of_axes + _validate_if_list_of_axes(axes, obligatory_len=5) + fig = axes[0].get_figure() + psd_args = dict() if psd_args is None else psd_args + topomap_args = dict() if topomap_args is None else topomap_args + image_args = dict() if image_args is None else image_args + for d in (psd_args, topomap_args, image_args): + if not isinstance(d, dict): + raise ValueError('topomap_args, image_args and psd_args have to be' + ' dictionaries, got %s instead.' % type(d)) + if dB is not None and isinstance(dB, bool) is False: + raise ValueError('dB should be bool, got %s instead' % + type(dB)) + + # calculations + # ------------ + plot_line_at_zero = False + if isinstance(inst, _BaseRaw): + # break up continuous signal into segments + from ..epochs import _segment_raw + inst = _segment_raw(inst, segment_length=2., verbose=False, + preload=True) + if inst.times[0] < 0. and inst.times[-1] > 0.: + plot_line_at_zero = True + + epochs_src = ica.get_sources(inst) + ica_data = np.swapaxes(epochs_src.get_data()[:, picks, :], 0, 1) + + # spectrum + Nyquist = inst.info['sfreq'] / 2. + if 'fmax' not in psd_args: + psd_args['fmax'] = min(inst.info['lowpass'] * 1.25, Nyquist) + plot_lowpass_edge = inst.info['lowpass'] < Nyquist and ( + psd_args['fmax'] > inst.info['lowpass']) + psds, freqs = psd_multitaper(epochs_src, picks=picks, **psd_args) + + def set_title_and_labels(ax, title, xlab, ylab): + if title: + ax.set_title(title) + if xlab: + ax.set_xlabel(xlab) + if ylab: + ax.set_ylabel(ylab) + ax.axis('auto') + ax.tick_params('both', labelsize=8) + ax.axis('tight') + + all_fig = list() + # the rest is component-specific + for idx, pick in enumerate(picks): + if idx > 0: + fig, axes = _create_properties_layout(figsize=figsize) + + # spectrum + this_psd = psds[:, idx, :] + if dB: + this_psd = 10 * np.log10(this_psd) + psds_mean = this_psd.mean(axis=0) + diffs = this_psd - psds_mean + # the distribution of power for each frequency bin is highly + # skewed so we calculate std for values below and above average + # separately - this is used for fill_between shade + spectrum_std = [ + [np.sqrt((d[d < 0] ** 2).mean(axis=0)) for d in diffs.T], + [np.sqrt((d[d > 0] ** 2).mean(axis=0)) for d in diffs.T]] + spectrum_std = np.array(spectrum_std) * num_std + + # erp std + if plot_std: + erp = ica_data[idx].mean(axis=0) + diffs = ica_data[idx] - erp + erp_std = [ + [np.sqrt((d[d < 0] ** 2).mean(axis=0)) for d in diffs.T], + [np.sqrt((d[d > 0] ** 2).mean(axis=0)) for d in diffs.T]] + erp_std = np.array(erp_std) * num_std + + # epoch variance + epoch_var = np.var(ica_data[idx], axis=1) + + # plotting + # -------- + # component topomap + _plot_ica_topomap(ica, pick, show=False, axes=axes[0], **topomap_args) + + # image and erp + plot_epochs_image(epochs_src, picks=pick, axes=axes[1:3], + colorbar=False, show=False, **image_args) + + # spectrum + axes[3].plot(freqs, psds_mean, color='k') + if plot_std: + axes[3].fill_between(freqs, psds_mean - spectrum_std[0], + psds_mean + spectrum_std[1], + color='k', alpha=.15) + if plot_lowpass_edge: + axes[3].axvline(inst.info['lowpass'], lw=2, linestyle='--', + color='k', alpha=0.15) + + # epoch variance + axes[4].scatter(range(len(epoch_var)), epoch_var, alpha=0.5, + facecolor=[0, 0, 0], lw=0) + + # aesthetics + # ---------- + axes[0].set_title('IC #{0:0>3}'.format(pick)) + + set_title_and_labels(axes[1], 'epochs image and ERP/ERF', [], 'Epochs') + + # erp + set_title_and_labels(axes[2], [], 'time', 'AU') + # line color and std + axes[2].lines[0].set_color('k') + if plot_std: + erp_xdata = axes[2].lines[0].get_data()[0] + axes[2].fill_between(erp_xdata, erp - erp_std[0], + erp + erp_std[1], color='k', alpha=.15) + axes[2].autoscale(enable=True, axis='y') + axes[2].axis('auto') + axes[2].set_xlim(erp_xdata[[0, -1]]) + # remove half of yticks if more than 5 + yt = axes[2].get_yticks() + if len(yt) > 5: + yt = yt[::2] + axes[2].yaxis.set_ticks(yt) + + if not plot_line_at_zero: + xlims = [1e3 * inst.times[0], 1e3 * inst.times[-1]] + for k, ax in enumerate(axes[1:3]): + ax.lines[k].remove() + ax.set_xlim(xlims) + + # remove xticks - erp plot shows xticks for both image and erp plot + axes[1].xaxis.set_ticks([]) + yt = axes[1].get_yticks() + axes[1].yaxis.set_ticks(yt[1:]) + axes[1].set_ylim([-0.5, ica_data.shape[1] + 0.5]) + + # spectrum + ylabel = 'dB' if dB else 'power' + set_title_and_labels(axes[3], 'spectrum', 'frequency', ylabel) + axes[3].yaxis.labelpad = 0 + axes[3].set_xlim(freqs[[0, -1]]) + ylim = axes[3].get_ylim() + air = np.diff(ylim)[0] * 0.1 + axes[3].set_ylim(ylim[0] - air, ylim[1] + air) + + # epoch variance + set_title_and_labels(axes[4], 'epochs variance', 'epoch', 'AU') + + all_fig.append(fig) + + plt_show(show) + return all_fig + + def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, labels=None): """Plot average over epochs in ICA space diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index 3149fc65422..4c3aa31a649 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -6,11 +6,12 @@ import os.path as op import warnings -from numpy.testing import assert_raises +from numpy.testing import assert_raises, assert_equal from mne import io, read_events, Epochs, read_cov from mne import pick_types from mne.utils import run_tests_if_main, requires_sklearn +from mne.viz.ica import _create_properties_layout, plot_ica_properties from mne.viz.utils import _fake_click from mne.preprocessing import ICA, create_ecg_epochs, create_eog_epochs @@ -71,6 +72,59 @@ def test_plot_ica_components(): plt.close('all') +@requires_sklearn +def test_plot_ica_properties(): + """Test plotting of ICA properties + """ + import matplotlib.pyplot as plt + + raw = _get_raw(preload=True) + events = _get_events() + picks = _get_picks(raw)[:6] + pick_names = [raw.ch_names[k] for k in picks] + raw.pick_channels(pick_names) + + with warnings.catch_warnings(record=True): # bad proj + epochs = Epochs(raw, events[:10], event_id, tmin, tmax, + baseline=(None, 0), preload=True) + + ica = ICA(noise_cov=read_cov(cov_fname), n_components=2, + max_pca_components=2, n_pca_components=2) + with warnings.catch_warnings(record=True): # bad proj + ica.fit(raw) + + # test _create_properties_layout + fig, ax = _create_properties_layout() + assert_equal(len(ax), 5) + + topoargs = dict(topomap_args={'res': 10}) + ica.plot_properties(raw, picks=0, **topoargs) + ica.plot_properties(epochs, picks=1, dB=False, plot_std=1.5, **topoargs) + ica.plot_properties(epochs, picks=1, image_args={'sigma': 1.5}, + topomap_args={'res': 10, 'colorbar': True}, + psd_args={'fmax': 65.}, plot_std=False, + figsize=[4.5, 4.5]) + plt.close('all') + + assert_raises(ValueError, ica.plot_properties, epochs, dB=list('abc')) + assert_raises(ValueError, ica.plot_properties, epochs, plot_std=[]) + assert_raises(ValueError, ica.plot_properties, ica) + assert_raises(ValueError, ica.plot_properties, [0.2]) + assert_raises(ValueError, plot_ica_properties, epochs, epochs) + assert_raises(ValueError, ica.plot_properties, epochs, + psd_args='not dict') + + fig, ax = plt.subplots(2, 3) + ax = ax.ravel()[:-1] + ica.plot_properties(epochs, picks=1, axes=ax) + fig = ica.plot_properties(raw, picks=[0, 1], **topoargs) + assert_equal(len(fig), 2) + assert_raises(ValueError, plot_ica_properties, epochs, ica, picks=[0, 1], + axes=ax) + assert_raises(ValueError, ica.plot_properties, epochs, axes='not axes') + plt.close('all') + + @requires_sklearn def test_plot_ica_sources(): """Test plotting of ICA panel diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 9afc9ddf760..12efa509e23 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -8,7 +8,8 @@ from nose.tools import assert_true, assert_raises from numpy.testing import assert_allclose -from mne.viz.utils import compare_fiff, _fake_click, _compute_scalings +from mne.viz.utils import (compare_fiff, _fake_click, _compute_scalings, + _validate_if_list_of_axes) from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap from mne.utils import run_tests_if_main from mne.io import read_raw_fif @@ -117,4 +118,23 @@ def test_auto_scale(): dict(grad='auto'), epochs) +def test_validate_if_list_of_axes(): + import matplotlib.pyplot as plt + fig, ax = plt.subplots(2, 2) + assert_raises(ValueError, _validate_if_list_of_axes, ax) + ax_flat = ax.ravel() + ax = ax.ravel().tolist() + _validate_if_list_of_axes(ax_flat) + _validate_if_list_of_axes(ax_flat, 4) + assert_raises(ValueError, _validate_if_list_of_axes, ax_flat, 5) + assert_raises(ValueError, _validate_if_list_of_axes, ax, 3) + assert_raises(ValueError, _validate_if_list_of_axes, 'error') + assert_raises(ValueError, _validate_if_list_of_axes, ['error'] * 2) + assert_raises(ValueError, _validate_if_list_of_axes, ax[0]) + assert_raises(ValueError, _validate_if_list_of_axes, ax, 3) + ax_flat[2] = 23 + assert_raises(ValueError, _validate_if_list_of_axes, ax_flat) + _validate_if_list_of_axes(ax, 4) + + run_tests_if_main() diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index e23601e9d2a..0bde35d8159 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -703,6 +703,56 @@ def _inside_contour(pos, contour): return check_mask +def _plot_ica_topomap(ica, idx=0, ch_type=None, res=64, layout=None, + vmin=None, vmax=None, cmap='RdBu_r', colorbar=False, + title=None, show=True, outlines='head', contours=6, + image_interp='bilinear', head_pos=None, axes=None): + """plot single ica map to axes""" + import matplotlib as mpl + from ..channels import _get_ch_type + from ..preprocessing.ica import _get_ica_map + + if ica.info is None: + raise RuntimeError('The ICA\'s measurement info is missing. Please ' + 'fit the ICA or add the corresponding info object.') + if not isinstance(axes, mpl.axes.Axes): + raise ValueError('axis has to be an instance of matplotlib Axes, ' + 'got %s instead.' % type(axes)) + ch_type = _get_ch_type(ica, ch_type) + + data = _get_ica_map(ica, components=idx) + data_picks, pos, merge_grads, names, _ = _prepare_topo_plot( + ica, ch_type, layout) + pos, outlines = _check_outlines(pos, outlines, head_pos) + if outlines not in (None, 'head'): + image_mask, pos = _make_image_mask(outlines, pos, res) + else: + image_mask = None + + data = np.atleast_2d(data) + data = data[:, data_picks] + + if merge_grads: + from ..channels.layout import _merge_grad_data + data = _merge_grad_data(data) + axes.set_title('IC #%03d' % idx, fontsize=12) + vmin_, vmax_ = _setup_vmin_vmax(data, vmin, vmax) + im = plot_topomap(data.ravel(), pos, vmin=vmin_, vmax=vmax_, + res=res, axes=axes, cmap=cmap, outlines=outlines, + image_mask=image_mask, contours=contours, + image_interp=image_interp, show=show)[0] + if colorbar: + import matplotlib.pyplot as plt + from mpl_toolkits.axes_grid import make_axes_locatable + divider = make_axes_locatable(axes) + cax = divider.append_axes("right", size="5%", pad=0.05) + cbar = plt.colorbar(im, cax=cax, format='%3.2f', cmap=cmap) + cbar.ax.tick_params(labelsize=12) + cbar.set_ticks((vmin_, vmax_)) + cbar.ax.set_title('AU', fontsize=10) + _hide_frame(axes) + + def plot_ica_components(ica, picks=None, ch_type=None, res=64, layout=None, vmin=None, vmax=None, cmap='RdBu_r', sensors=True, colorbar=False, title=None, diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 07f165014d1..978a48ce138 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -109,6 +109,32 @@ def _check_delayed_ssp(container): raise RuntimeError('No projs found in evoked.') +def _validate_if_list_of_axes(axes, obligatory_len=None): + """ Helper function that validates whether input is a list/array of axes""" + import matplotlib as mpl + if obligatory_len is not None and not isinstance(obligatory_len, int): + raise ValueError('obligatory_len must be None or int, got %d', + 'instead' % type(obligatory_len)) + if not isinstance(axes, (list, np.ndarray)): + raise ValueError('axes must be a list or numpy array of matplotlib ' + 'axes objects, got %s instead.' % type(axes)) + if isinstance(axes, np.ndarray) and axes.ndim > 1: + raise ValueError('if input is a numpy array, it must be ' + 'one-dimensional. The received numpy array has %d ' + 'dimensions however. Try using ravel or flatten ' + 'method of the array.' % axes.ndim) + is_correct_type = np.array([isinstance(x, mpl.axes.Axes) + for x in axes]) + if not np.all(is_correct_type): + first_bad = np.where(np.logical_not(is_correct_type))[0][0] + raise ValueError('axes must be a list or numpy array of matplotlib ' + 'axes objects while one of the list elements is ' + '%s.' % type(axes[first_bad])) + if obligatory_len is not None and not len(axes) == obligatory_len: + raise ValueError('axes must be a list/array of length %d, while the' + ' length is %d' % (obligatory_len, len(axes))) + + def mne_analyze_colormap(limits=[5, 10, 15], format='mayavi'): """Return a colormap similar to that used by mne_analyze diff --git a/tutorials/plot_artifacts_correction_ica.py b/tutorials/plot_artifacts_correction_ica.py index 8fe0078d597..73681141114 100644 --- a/tutorials/plot_artifacts_correction_ica.py +++ b/tutorials/plot_artifacts_correction_ica.py @@ -49,9 +49,14 @@ method = 'fastica' # for comparison with EEGLAB try "extended-infomax" here decim = 3 # we need sufficient statistics, not all time points -> save time +# we will also set state of the random number generator - ICA is a +# non-deterministic algorithm, but we want to have the same decomposition +# and the same order of components each time this tutorial is run +random_state = 23 + ############################################################################### # Define the ICA object instance -ica = ICA(n_components=n_components, method=method) +ica = ICA(n_components=n_components, method=method, random_state=random_state) print(ica) ############################################################################### @@ -67,6 +72,26 @@ ica.plot_components() # can you see some potential bad guys? +############################################################################### +# Component properties +# -------------------- +# +# Let's take a closer look at three potential candidates for artifact-related +# components: IC 12, IC 15 and IC 21 + +# first, component 12: +ica.plot_properties(raw, picks=12, dB=True) + +############################################################################### +# it looks like a blink component, but because the data were filtered +# the spectrum plot is not very informative, let's change that: +ica.plot_properties(raw, picks=12, dB=True, psd_args={'fmax': 35.}) + +############################################################################### +# now let's inspect properties of components 15 and 21 (cardiac activity): +ica.plot_properties(raw, picks=[15, 21], psd_args={'fmax': 35.}) + + ############################################################################### # Advanced artifact detection # --------------------------- @@ -88,8 +113,15 @@ ica.plot_sources(eog_average, exclude=eog_inds) # look at source time course ############################################################################### -# That component is also showing a prototypical average vertical EOG time -# course. +# We can take a look at the properties of that component again, now using the +# data epoched with respect to EOG events. +# We will also use a little bit of smoothing along the trials axis in the +# epochs image: +ica.plot_properties(eog_epochs, picks=eog_inds, dB=True, + psd_args={'fmax': 35.}, image_args={'sigma': 1.}) + +############################################################################### +# That component is showing a prototypical average vertical EOG time course. # # Pay attention to the labels, a customized read-out of the # :attr:`ica.labels_ `