Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,9 +833,40 @@ def add_channels(self, add_list, force_update_info=False):
class InterpolationMixin(object):
"""Mixin class for Raw, Evoked, Epochs."""

@verbose
def compute_interpolation_matrix(self, mode='accurate', verbose=None):
"""Compute matrix to interpolate bad MEG and EEG channels.

Parameters
----------
mode : str
Either ``'accurate'`` or ``'fast'``, determines the quality of the
Legendre polynomial expansion used for interpolation of MEG
channels.
verbose : bool, str, int, or None
If not None, override default verbose level (see
:func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
for more).

Returns
-------
interpolator : Interpolator
Matrix and indices needed for interpolating bad channels.

Notes
-----
.. versionadded:: 0.16.0

Use this function to precompute an interpolation matrix for
:meth:`interpolate_bads`.
"""
from .interpolation import _compute_interpolation_matrix

return _compute_interpolation_matrix(self, mode)

@verbose
def interpolate_bads(self, reset_bads=True, mode='accurate',
verbose=None):
interpolator=None, verbose=None):
"""Interpolate bad MEG and EEG channels.

Operates in place.
Expand All @@ -848,6 +879,10 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
Either ``'accurate'`` or ``'fast'``, determines the quality of the
Legendre polynomial expansion used for interpolation of MEG
channels.
interpolator : bool
Interpolation matrix precomputed with
:meth:`compute_interpolation_matrix`. If specified, ``mode`` is
ignored.
verbose : bool, str, int, or None
If not None, override default verbose level (see
:func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
Expand All @@ -861,14 +896,18 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
Notes
-----
.. versionadded:: 0.9.0
"""
from .interpolation import _interpolate_bads_eeg, _interpolate_bads_meg

See Also
--------
compute_interpolation_matrix : precompute interpolation matrix
"""
if getattr(self, 'preload', None) is False:
raise ValueError('Data must be preloaded.')

_interpolate_bads_eeg(self)
_interpolate_bads_meg(self, mode=mode)
if interpolator is None:
interpolator = self.compute_interpolation_matrix(mode)

interpolator.apply_in_place(self)

if reset_bads is True:
self.info['bads'] = []
Expand Down
77 changes: 57 additions & 20 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,53 @@
from ..forward import _map_meg_channels


class Interpolator(dict):
"""Precomputed interpolation matrix.

Parameters
----------
eeg : tuple | None
EEG interpolation parameters (output of _interpolate_bads_eeg).
meg : tuple | None
MEG interpolation parameters (output of _interpolate_bads_meg).
"""

def __init__(self, eeg, meg):
dict.__init__(self, eeg=eeg, meg=meg)

def __repr__(self): # noqa: D105
desc = []
if self['eeg'] is not None:
desc.append('%i bad EEG' % (self['eeg'][1].sum(),))
if self['meg'] is not None:
desc.append('%i bad MEG' % (len(self['meg'][1]),))
if not desc:
desc.append('0 bad')
"<Interpolator | %s>" % (', '.join(desc),)

def apply_in_place(self, inst):
"""Dot product of channel mapping matrix to channel data."""
if self['eeg'] is not None:
self._apply_one(inst, *self['eeg'])
if self['meg'] is not None:
self._apply_one(inst, *self['meg'])

@staticmethod
def _apply_one(inst, interpolator, goods_index, bads_index):
from ..io.base import BaseRaw
from ..epochs import BaseEpochs
from ..evoked import Evoked

if isinstance(inst, (BaseRaw, Evoked)):
inst._data[bads_index] = interpolator.dot(inst._data[goods_index])
elif isinstance(inst, BaseEpochs):
inst._data[:, bads_index, :] = np.einsum(
'ij,xjy->xiy', interpolator, inst._data[:, goods_index, :])
else:
raise ValueError('Inputs of type {0} are not supported'
.format(type(inst)))


def _calc_g(cosang, stiffness=4, num_lterms=50):
"""Calculate spherical spline g function between points on a sphere.

Expand All @@ -36,6 +83,13 @@ def _calc_g(cosang, stiffness=4, num_lterms=50):
return legval(cosang, [0] + factors)


def _compute_interpolation_matrix(inst, mode):
"""Implement InterpolationMixin.compute_interpolation_matrix()."""
interp_eeg = _interpolate_bads_eeg(inst)
interp_meg = _interpolate_bads_meg(inst, mode=mode)
return Interpolator(interp_eeg, interp_meg)


def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
"""Compute interpolation matrix based on spherical splines.

Expand Down Expand Up @@ -88,22 +142,6 @@ def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
return interpolation


def _do_interp_dots(inst, interpolation, goods_idx, bads_idx):
"""Dot product of channel mapping matrix to channel data."""
from ..io.base import BaseRaw
from ..epochs import BaseEpochs
from ..evoked import Evoked

if isinstance(inst, (BaseRaw, Evoked)):
inst._data[bads_idx] = interpolation.dot(inst._data[goods_idx])
elif isinstance(inst, BaseEpochs):
inst._data[:, bads_idx, :] = np.einsum('ij,xjy->xiy', interpolation,
inst._data[:, goods_idx, :])
else:
raise ValueError('Inputs of type {0} are not supported'
.format(type(inst)))


@verbose
def _interpolate_bads_eeg(inst, verbose=None):
"""Interpolate bad EEG channels.
Expand Down Expand Up @@ -151,7 +189,7 @@ def _interpolate_bads_eeg(inst, verbose=None):
interpolation = _make_interpolation_matrix(pos_good, pos_bad)

logger.info('Interpolating {0} sensors'.format(len(pos_bad)))
_do_interp_dots(inst, interpolation, goods_idx, bads_idx)
return interpolation, goods_idx, bads_idx


@verbose
Expand Down Expand Up @@ -179,13 +217,12 @@ def _interpolate_bads_meg(inst, mode='accurate', verbose=None):
if len(bads_meg) == 0:
picks_bad = []
else:
picks_bad = pick_channels(inst.info['ch_names'], bads_meg,
exclude=[])
picks_bad = pick_channels(inst.info['ch_names'], bads_meg, exclude=[])

# return without doing anything if there are no meg channels
if len(picks_meg) == 0 or len(picks_bad) == 0:
return
info_from = pick_info(inst.info, picks_good)
info_to = pick_info(inst.info, picks_bad)
mapping = _map_meg_channels(info_from, info_to, mode=mode)
_do_interp_dots(inst, mapping, picks_good, picks_bad)
return mapping, picks_good, picks_bad