diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 1dd9a094595..213c39d26a3 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -833,9 +833,40 @@ def add_channels(self, add_list, force_update_info=False): class InterpolationMixin(object): """Mixin class for Raw, Evoked, Epochs.""" + @verbose + def compute_interpolation_matrix(self, mode='accurate', verbose=None): + """Compute matrix to interpolate bad MEG and EEG channels. + + Parameters + ---------- + mode : str + Either ``'accurate'`` or ``'fast'``, determines the quality of the + Legendre polynomial expansion used for interpolation of MEG + channels. + verbose : bool, str, int, or None + If not None, override default verbose level (see + :func:`mne.verbose` and :ref:`Logging documentation ` + for more). + + Returns + ------- + interpolator : Interpolator + Matrix and indices needed for interpolating bad channels. + + Notes + ----- + .. versionadded:: 0.16.0 + + Use this function to precompute an interpolation matrix for + :meth:`interpolate_bads`. + """ + from .interpolation import _compute_interpolation_matrix + + return _compute_interpolation_matrix(self, mode) + @verbose def interpolate_bads(self, reset_bads=True, mode='accurate', - verbose=None): + interpolator=None, verbose=None): """Interpolate bad MEG and EEG channels. Operates in place. @@ -848,6 +879,10 @@ def interpolate_bads(self, reset_bads=True, mode='accurate', Either ``'accurate'`` or ``'fast'``, determines the quality of the Legendre polynomial expansion used for interpolation of MEG channels. + interpolator : bool + Interpolation matrix precomputed with + :meth:`compute_interpolation_matrix`. If specified, ``mode`` is + ignored. verbose : bool, str, int, or None If not None, override default verbose level (see :func:`mne.verbose` and :ref:`Logging documentation ` @@ -861,14 +896,18 @@ def interpolate_bads(self, reset_bads=True, mode='accurate', Notes ----- .. versionadded:: 0.9.0 - """ - from .interpolation import _interpolate_bads_eeg, _interpolate_bads_meg + See Also + -------- + compute_interpolation_matrix : precompute interpolation matrix + """ if getattr(self, 'preload', None) is False: raise ValueError('Data must be preloaded.') - _interpolate_bads_eeg(self) - _interpolate_bads_meg(self, mode=mode) + if interpolator is None: + interpolator = self.compute_interpolation_matrix(mode) + + interpolator.apply_in_place(self) if reset_bads is True: self.info['bads'] = [] diff --git a/mne/channels/interpolation.py b/mne/channels/interpolation.py index 9bfc77a421a..5ae4941b3c6 100644 --- a/mne/channels/interpolation.py +++ b/mne/channels/interpolation.py @@ -13,6 +13,53 @@ from ..forward import _map_meg_channels +class Interpolator(dict): + """Precomputed interpolation matrix. + + Parameters + ---------- + eeg : tuple | None + EEG interpolation parameters (output of _interpolate_bads_eeg). + meg : tuple | None + MEG interpolation parameters (output of _interpolate_bads_meg). + """ + + def __init__(self, eeg, meg): + dict.__init__(self, eeg=eeg, meg=meg) + + def __repr__(self): # noqa: D105 + desc = [] + if self['eeg'] is not None: + desc.append('%i bad EEG' % (self['eeg'][1].sum(),)) + if self['meg'] is not None: + desc.append('%i bad MEG' % (len(self['meg'][1]),)) + if not desc: + desc.append('0 bad') + "" % (', '.join(desc),) + + def apply_in_place(self, inst): + """Dot product of channel mapping matrix to channel data.""" + if self['eeg'] is not None: + self._apply_one(inst, *self['eeg']) + if self['meg'] is not None: + self._apply_one(inst, *self['meg']) + + @staticmethod + def _apply_one(inst, interpolator, goods_index, bads_index): + from ..io.base import BaseRaw + from ..epochs import BaseEpochs + from ..evoked import Evoked + + if isinstance(inst, (BaseRaw, Evoked)): + inst._data[bads_index] = interpolator.dot(inst._data[goods_index]) + elif isinstance(inst, BaseEpochs): + inst._data[:, bads_index, :] = np.einsum( + 'ij,xjy->xiy', interpolator, inst._data[:, goods_index, :]) + else: + raise ValueError('Inputs of type {0} are not supported' + .format(type(inst))) + + def _calc_g(cosang, stiffness=4, num_lterms=50): """Calculate spherical spline g function between points on a sphere. @@ -36,6 +83,13 @@ def _calc_g(cosang, stiffness=4, num_lterms=50): return legval(cosang, [0] + factors) +def _compute_interpolation_matrix(inst, mode): + """Implement InterpolationMixin.compute_interpolation_matrix().""" + interp_eeg = _interpolate_bads_eeg(inst) + interp_meg = _interpolate_bads_meg(inst, mode=mode) + return Interpolator(interp_eeg, interp_meg) + + def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5): """Compute interpolation matrix based on spherical splines. @@ -88,22 +142,6 @@ def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5): return interpolation -def _do_interp_dots(inst, interpolation, goods_idx, bads_idx): - """Dot product of channel mapping matrix to channel data.""" - from ..io.base import BaseRaw - from ..epochs import BaseEpochs - from ..evoked import Evoked - - if isinstance(inst, (BaseRaw, Evoked)): - inst._data[bads_idx] = interpolation.dot(inst._data[goods_idx]) - elif isinstance(inst, BaseEpochs): - inst._data[:, bads_idx, :] = np.einsum('ij,xjy->xiy', interpolation, - inst._data[:, goods_idx, :]) - else: - raise ValueError('Inputs of type {0} are not supported' - .format(type(inst))) - - @verbose def _interpolate_bads_eeg(inst, verbose=None): """Interpolate bad EEG channels. @@ -151,7 +189,7 @@ def _interpolate_bads_eeg(inst, verbose=None): interpolation = _make_interpolation_matrix(pos_good, pos_bad) logger.info('Interpolating {0} sensors'.format(len(pos_bad))) - _do_interp_dots(inst, interpolation, goods_idx, bads_idx) + return interpolation, goods_idx, bads_idx @verbose @@ -179,8 +217,7 @@ def _interpolate_bads_meg(inst, mode='accurate', verbose=None): if len(bads_meg) == 0: picks_bad = [] else: - picks_bad = pick_channels(inst.info['ch_names'], bads_meg, - exclude=[]) + picks_bad = pick_channels(inst.info['ch_names'], bads_meg, exclude=[]) # return without doing anything if there are no meg channels if len(picks_meg) == 0 or len(picks_bad) == 0: @@ -188,4 +225,4 @@ def _interpolate_bads_meg(inst, mode='accurate', verbose=None): info_from = pick_info(inst.info, picks_good) info_to = pick_info(inst.info, picks_bad) mapping = _map_meg_channels(info_from, info_to, mode=mode) - _do_interp_dots(inst, mapping, picks_good, picks_bad) + return mapping, picks_good, picks_bad