@@ -699,24 +699,26 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
699699 specified through `low` and `high`. Support broadcasting.
700700 """
701701 np .testing .assert_equal (low .shape , high .shape )
702- ni , _ , nk = arr .shape [:axis ], arr .shape [axis ], arr .shape [axis + 1 :]
703- si , j , sk = low .shape [:axis ], low .shape [axis ], low .shape [axis + 1 :]
704- mk = max (nk , sk )
705- mi = max (ni , si )
706- out = np .empty (mi + (j ,) + mk )
707- for ki in np .ndindex (ni ):
708- for kk in np .ndindex (mk ):
709- ak = tuple (np .mod (kk , nk ))
710- ik = tuple (np .mod (kk , sk ))
711- ai = tuple (np .mod (ki , ni ))
712- ii = tuple (np .mod (ki , si ))
713- a_1d = arr [ai + np .s_ [:, ] + ak ]
714- out_1d = out [ki + np .s_ [:, ] + kk ]
715- low_1d = low [ii + np .s_ [:, ] + ik ]
716- high_1d = high [ii + np .s_ [:, ] + ik ]
717-
718- for r in range (j ):
719- out_1d [r ] = func (a_1d [low_1d [r ]:high_1d [r ]])
702+
703+ def apply_func (vector , l , h ):
704+ return func (vector [l :h ])
705+
706+ apply_func_1d = np .vectorize (apply_func , signature = '(n), (), ()->()' )
707+ vectorized_func = np .vectorize (apply_func_1d ,
708+ signature = '(n), (k), (k)->(m)' )
709+
710+ # Put `axis` at the innermost dimension
711+ dims = list (range (arr .ndim ))
712+ dims [- 1 ] = axis
713+ dims [axis ] = arr .ndim - 1
714+ t_arr = np .transpose (arr , axes = dims )
715+ t_low = np .transpose (low , axes = dims )
716+ t_high = np .transpose (high , axes = dims )
717+
718+ t_out = vectorized_func (t_arr , t_low , t_high )
719+
720+ # Replace `axis` at its place
721+ out = np .transpose (t_out , axes = dims )
720722 return out
721723
722724 def check_gaussian_windowed (self , shape , indice_shape , axis ,
@@ -797,10 +799,6 @@ def test_windowed_mean_graph(self):
797799 def test_windowed_variance (self ):
798800 self .check_windowed (func = sample_stats .windowed_variance , numpy_func = np .var )
799801
800- def test_windowed_variance_graph (self ):
801- func = tf .function (sample_stats .windowed_variance )
802- self .check_windowed (func = func , numpy_func = np .var )
803-
804802
805803@test_util .test_all_tf_execution_regimes
806804class LogAverageProbsTest (test_util .TestCase ):
0 commit comments