Skip to content

Commit ecebf20

Browse files
ENH: Interpolator to cache interpolation matrix
1 parent b5a2a96 commit ecebf20

File tree

2 files changed

+90
-25
lines changed

2 files changed

+90
-25
lines changed

mne/channels/channels.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -833,9 +833,40 @@ def add_channels(self, add_list, force_update_info=False):
833833
class InterpolationMixin(object):
834834
"""Mixin class for Raw, Evoked, Epochs."""
835835

836+
@verbose
837+
def compute_interpolation_matrix(self, mode='accurate', verbose=None):
838+
"""Compute matrix to interpolate bad MEG and EEG channels.
839+
840+
Parameters
841+
----------
842+
mode : str
843+
Either ``'accurate'`` or ``'fast'``, determines the quality of the
844+
Legendre polynomial expansion used for interpolation of MEG
845+
channels.
846+
verbose : bool, str, int, or None
847+
If not None, override default verbose level (see
848+
:func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
849+
for more).
850+
851+
Returns
852+
-------
853+
interpolator : Interpolator
854+
Matrix and indices needed for interpolating bad channels.
855+
856+
Notes
857+
-----
858+
.. versionadded:: 0.16.0
859+
860+
Use this function to precompute an interpolation matrix for
861+
:meth:`interpolate_bads`.
862+
"""
863+
from .interpolation import _compute_interpolation_matrix
864+
865+
return _compute_interpolation_matrix(self, mode)
866+
836867
@verbose
837868
def interpolate_bads(self, reset_bads=True, mode='accurate',
838-
verbose=None):
869+
interpolator=None, verbose=None):
839870
"""Interpolate bad MEG and EEG channels.
840871
841872
Operates in place.
@@ -848,6 +879,10 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
848879
Either ``'accurate'`` or ``'fast'``, determines the quality of the
849880
Legendre polynomial expansion used for interpolation of MEG
850881
channels.
882+
interpolator : bool
883+
Interpolation matrix precomputed with
884+
:meth:`compute_interpolation_matrix`. If specified, ``mode`` is
885+
ignored.
851886
verbose : bool, str, int, or None
852887
If not None, override default verbose level (see
853888
:func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
@@ -861,14 +896,18 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
861896
Notes
862897
-----
863898
.. versionadded:: 0.9.0
864-
"""
865-
from .interpolation import _interpolate_bads_eeg, _interpolate_bads_meg
866899
900+
See Also
901+
--------
902+
compute_interpolation_matrix : precompute interpolation matrix
903+
"""
867904
if getattr(self, 'preload', None) is False:
868905
raise ValueError('Data must be preloaded.')
869906

870-
_interpolate_bads_eeg(self)
871-
_interpolate_bads_meg(self, mode=mode)
907+
if interpolator is None:
908+
interpolator = self.compute_interpolation_matrix(mode)
909+
910+
interpolator.apply_in_place(self)
872911

873912
if reset_bads is True:
874913
self.info['bads'] = []

mne/channels/interpolation.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,42 @@
1313
from ..forward import _map_meg_channels
1414

1515

16+
class Interpolator(dict):
17+
"""Precomputed interpolation matrix.
18+
19+
Parameters
20+
----------
21+
eeg : tuple | None
22+
EEG interpolation parameters (output of _interpolate_bads_eeg).
23+
meg : tuple | None
24+
MEG interpolation parameters (output of _interpolate_bads_meg).
25+
"""
26+
def __init__(self, eeg, meg):
27+
dict.__init__(self, eeg=eeg, meg=meg)
28+
29+
def apply_in_place(self, inst):
30+
"""Dot product of channel mapping matrix to channel data."""
31+
if self['eeg'] is not None:
32+
self._apply_one(inst, *self['eeg'])
33+
if self['meg'] is not None:
34+
self._apply_one(inst, *self['meg'])
35+
36+
@staticmethod
37+
def _apply_one(inst, interpolator, goods_index, bads_index):
38+
from ..io.base import BaseRaw
39+
from ..epochs import BaseEpochs
40+
from ..evoked import Evoked
41+
42+
if isinstance(inst, (BaseRaw, Evoked)):
43+
inst._data[bads_index] = interpolator.dot(inst._data[goods_index])
44+
elif isinstance(inst, BaseEpochs):
45+
inst._data[:, bads_index, :] = np.einsum(
46+
'ij,xjy->xiy', interpolator, inst._data[:, goods_index, :])
47+
else:
48+
raise ValueError('Inputs of type {0} are not supported'
49+
.format(type(inst)))
50+
51+
1652
def _calc_g(cosang, stiffness=4, num_lterms=50):
1753
"""Calculate spherical spline g function between points on a sphere.
1854
@@ -36,6 +72,13 @@ def _calc_g(cosang, stiffness=4, num_lterms=50):
3672
return legval(cosang, [0] + factors)
3773

3874

75+
def _compute_interpolation_matrix(inst, mode):
76+
"""Implement InterpolationMixin.compute_interpolation_matrix()."""
77+
interp_eeg = _interpolate_bads_eeg(inst)
78+
interp_meg = _interpolate_bads_meg(inst, mode=mode)
79+
return Interpolator(interp_eeg, interp_meg)
80+
81+
3982
def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
4083
"""Compute interpolation matrix based on spherical splines.
4184
@@ -88,22 +131,6 @@ def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
88131
return interpolation
89132

90133

91-
def _do_interp_dots(inst, interpolation, goods_idx, bads_idx):
92-
"""Dot product of channel mapping matrix to channel data."""
93-
from ..io.base import BaseRaw
94-
from ..epochs import BaseEpochs
95-
from ..evoked import Evoked
96-
97-
if isinstance(inst, (BaseRaw, Evoked)):
98-
inst._data[bads_idx] = interpolation.dot(inst._data[goods_idx])
99-
elif isinstance(inst, BaseEpochs):
100-
inst._data[:, bads_idx, :] = np.einsum('ij,xjy->xiy', interpolation,
101-
inst._data[:, goods_idx, :])
102-
else:
103-
raise ValueError('Inputs of type {0} are not supported'
104-
.format(type(inst)))
105-
106-
107134
@verbose
108135
def _interpolate_bads_eeg(inst, verbose=None):
109136
"""Interpolate bad EEG channels.
@@ -151,7 +178,7 @@ def _interpolate_bads_eeg(inst, verbose=None):
151178
interpolation = _make_interpolation_matrix(pos_good, pos_bad)
152179

153180
logger.info('Interpolating {0} sensors'.format(len(pos_bad)))
154-
_do_interp_dots(inst, interpolation, goods_idx, bads_idx)
181+
return interpolation, goods_idx, bads_idx
155182

156183

157184
@verbose
@@ -179,13 +206,12 @@ def _interpolate_bads_meg(inst, mode='accurate', verbose=None):
179206
if len(bads_meg) == 0:
180207
picks_bad = []
181208
else:
182-
picks_bad = pick_channels(inst.info['ch_names'], bads_meg,
183-
exclude=[])
209+
picks_bad = pick_channels(inst.info['ch_names'], bads_meg, exclude=[])
184210

185211
# return without doing anything if there are no meg channels
186212
if len(picks_meg) == 0 or len(picks_bad) == 0:
187213
return
188214
info_from = pick_info(inst.info, picks_good)
189215
info_to = pick_info(inst.info, picks_bad)
190216
mapping = _map_meg_channels(info_from, info_to, mode=mode)
191-
_do_interp_dots(inst, mapping, picks_good, picks_bad)
217+
return mapping, picks_good, picks_bad

0 commit comments

Comments
 (0)