1313from ..forward import _map_meg_channels
1414
1515
16+ class Interpolator (dict ):
17+ """Precomputed interpolation matrix.
18+
19+ Parameters
20+ ----------
21+ eeg : tuple | None
22+ EEG interpolation parameters (output of _interpolate_bads_eeg).
23+ meg : tuple | None
24+ MEG interpolation parameters (output of _interpolate_bads_meg).
25+ """
26+ def __init__ (self , eeg , meg ):
27+ dict .__init__ (self , eeg = eeg , meg = meg )
28+
29+ def apply_in_place (self , inst ):
30+ """Dot product of channel mapping matrix to channel data."""
31+ if self ['eeg' ] is not None :
32+ self ._apply_one (inst , * self ['eeg' ])
33+ if self ['meg' ] is not None :
34+ self ._apply_one (inst , * self ['meg' ])
35+
36+ @staticmethod
37+ def _apply_one (inst , interpolator , goods_index , bads_index ):
38+ from ..io .base import BaseRaw
39+ from ..epochs import BaseEpochs
40+ from ..evoked import Evoked
41+
42+ if isinstance (inst , (BaseRaw , Evoked )):
43+ inst ._data [bads_index ] = interpolator .dot (inst ._data [goods_index ])
44+ elif isinstance (inst , BaseEpochs ):
45+ inst ._data [:, bads_index , :] = np .einsum (
46+ 'ij,xjy->xiy' , interpolator , inst ._data [:, goods_index , :])
47+ else :
48+ raise ValueError ('Inputs of type {0} are not supported'
49+ .format (type (inst )))
50+
51+
1652def _calc_g (cosang , stiffness = 4 , num_lterms = 50 ):
1753 """Calculate spherical spline g function between points on a sphere.
1854
@@ -36,6 +72,13 @@ def _calc_g(cosang, stiffness=4, num_lterms=50):
3672 return legval (cosang , [0 ] + factors )
3773
3874
75+ def _compute_interpolation_matrix (inst , mode ):
76+ """Implement InterpolationMixin.compute_interpolation_matrix()."""
77+ interp_eeg = _interpolate_bads_eeg (inst )
78+ interp_meg = _interpolate_bads_meg (inst , mode = mode )
79+ return Interpolator (interp_eeg , interp_meg )
80+
81+
3982def _make_interpolation_matrix (pos_from , pos_to , alpha = 1e-5 ):
4083 """Compute interpolation matrix based on spherical splines.
4184
@@ -88,22 +131,6 @@ def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
88131 return interpolation
89132
90133
91- def _do_interp_dots (inst , interpolation , goods_idx , bads_idx ):
92- """Dot product of channel mapping matrix to channel data."""
93- from ..io .base import BaseRaw
94- from ..epochs import BaseEpochs
95- from ..evoked import Evoked
96-
97- if isinstance (inst , (BaseRaw , Evoked )):
98- inst ._data [bads_idx ] = interpolation .dot (inst ._data [goods_idx ])
99- elif isinstance (inst , BaseEpochs ):
100- inst ._data [:, bads_idx , :] = np .einsum ('ij,xjy->xiy' , interpolation ,
101- inst ._data [:, goods_idx , :])
102- else :
103- raise ValueError ('Inputs of type {0} are not supported'
104- .format (type (inst )))
105-
106-
107134@verbose
108135def _interpolate_bads_eeg (inst , verbose = None ):
109136 """Interpolate bad EEG channels.
@@ -151,7 +178,7 @@ def _interpolate_bads_eeg(inst, verbose=None):
151178 interpolation = _make_interpolation_matrix (pos_good , pos_bad )
152179
153180 logger .info ('Interpolating {0} sensors' .format (len (pos_bad )))
154- _do_interp_dots ( inst , interpolation , goods_idx , bads_idx )
181+ return interpolation , goods_idx , bads_idx
155182
156183
157184@verbose
@@ -179,13 +206,12 @@ def _interpolate_bads_meg(inst, mode='accurate', verbose=None):
179206 if len (bads_meg ) == 0 :
180207 picks_bad = []
181208 else :
182- picks_bad = pick_channels (inst .info ['ch_names' ], bads_meg ,
183- exclude = [])
209+ picks_bad = pick_channels (inst .info ['ch_names' ], bads_meg , exclude = [])
184210
185211 # return without doing anything if there are no meg channels
186212 if len (picks_meg ) == 0 or len (picks_bad ) == 0 :
187213 return
188214 info_from = pick_info (inst .info , picks_good )
189215 info_to = pick_info (inst .info , picks_bad )
190216 mapping = _map_meg_channels (info_from , info_to , mode = mode )
191- _do_interp_dots ( inst , mapping , picks_good , picks_bad )
217+ return mapping , picks_good , picks_bad
0 commit comments