|
15 | 15 | """Tests for Sample Stats Ops.""" |
16 | 16 |
|
17 | 17 | # Dependency imports |
18 | | -import functools |
| 18 | +import itertools |
| 19 | + |
19 | 20 | import numpy as np |
20 | 21 | import tensorflow.compat.v1 as tf1 |
21 | 22 | import tensorflow.compat.v2 as tf |
| 23 | +from absl.testing import parameterized |
| 24 | +from tensorflow.python.framework.errors_impl import InvalidArgumentError |
| 25 | + |
22 | 26 | from tensorflow_probability.python.internal import test_util |
23 | 27 | from tensorflow_probability.python.stats import sample_stats |
24 | 28 |
|
@@ -721,7 +725,8 @@ def apply_func(vector, l, h): |
721 | 725 | out = np.transpose(t_out, axes=dims) |
722 | 726 | return out |
723 | 727 |
|
724 | | - def check_gaussian_windowed(self, shape, indice_shape, axis, |
| 728 | + |
| 729 | + def check_gaussian_windowed_func(self, shape, indice_shape, axis, |
725 | 730 | window_func, np_func): |
726 | 731 | stat_shape = np.array(shape).astype(np.int32) |
727 | 732 | stat_shape[axis] = 1 |
@@ -753,51 +758,56 @@ def check_gaussian_windowed(self, shape, indice_shape, axis, |
753 | 758 | def _make_dynamic_shape(self, x): |
754 | 759 | return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape)) |
755 | 760 |
|
756 | | - def check_windowed(self, func, numpy_func): |
757 | | - check_fn = functools.partial(self.check_gaussian_windowed, |
758 | | - window_func=func, np_func=numpy_func) |
759 | | - check_fn((64, 4, 8), (128, 1, 1), axis=0) |
760 | | - check_fn((64, 4, 8), (32, 1, 1), axis=0) |
761 | | - check_fn((64, 4, 8), (32, 4, 1), axis=0) |
762 | | - check_fn((64, 4, 8), (32, 4, 8), axis=0) |
763 | | - check_fn((64, 4, 8), (64, 4, 8), axis=0) |
764 | | - check_fn((64, 4, 8), (128, 1), axis=0) |
765 | | - check_fn((64, 4, 8), (32,), axis=0) |
766 | | - check_fn((64, 4, 8), (32, 4), axis=0) |
767 | | - |
768 | | - check_fn((64, 4, 8), (64, 64, 1), axis=1) |
769 | | - check_fn((64, 4, 8), (1, 64, 1), axis=1) |
770 | | - check_fn((64, 4, 8), (64, 2, 8), axis=1) |
771 | | - check_fn((64, 4, 8), (64, 4, 8), axis=1) |
772 | | - check_fn((64, 4, 8), (16,), axis=1) |
773 | | - check_fn((64, 4, 8), (1, 64), axis=1) |
774 | | - |
775 | | - check_fn((64, 4, 8), (64, 4, 64), axis=2) |
776 | | - check_fn((64, 4, 8), (1, 1, 64), axis=2) |
777 | | - check_fn((64, 4, 8), (64, 4, 4), axis=2) |
778 | | - check_fn((64, 4, 8), (1, 1, 4), axis=2) |
779 | | - check_fn((64, 4, 8), (64, 4, 8), axis=2) |
780 | | - check_fn((64, 4, 8), (16,), axis=2) |
781 | | - check_fn((64, 4, 8), (1, 4), axis=2) |
782 | | - check_fn((64, 4, 8), (64, 4), axis=2) |
783 | | - |
784 | | - with self.assertRaises(Exception): |
785 | | - # Non broadcastable shapes |
786 | | - check_fn((64, 4, 8), (4, 1, 4), axis=2) |
787 | | - |
788 | | - with self.assertRaises(Exception): |
789 | | - # Non broadcastable shapes |
790 | | - check_fn((64, 4, 8), (2, 4), axis=2) |
791 | | - |
792 | | - def test_windowed_mean(self): |
793 | | - self.check_windowed(func=sample_stats.windowed_mean, numpy_func=np.mean) |
794 | | - |
795 | | - def test_windowed_mean_graph(self): |
796 | | - func = tf.function(sample_stats.windowed_mean) |
797 | | - self.check_windowed(func=func, numpy_func=np.mean) |
798 | | - |
799 | | - def test_windowed_variance(self): |
800 | | - self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var) |
| 761 | + @parameterized.named_parameters(*[( |
| 762 | + f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis, |
| 763 | + tf_func, np_func) for a, (b, axis), (tf_func, np_func) in |
| 764 | + itertools.product([(64, 4, 8), ], |
| 765 | + [((128, 1, 1), 0), |
| 766 | + ((32, 1, 1), 0), |
| 767 | + ((32, 4, 1), 0), |
| 768 | + ((32, 4, 8), 0), |
| 769 | + ((64, 4, 8), 0), |
| 770 | + ((128, 1), 0), |
| 771 | + ((32,), 0), |
| 772 | + ((32, 4), 0), |
| 773 | +
|
| 774 | + ((64, 64, 1), 1), |
| 775 | + ((1, 64, 1), 1), |
| 776 | + ((64, 2, 8), 1), |
| 777 | + ((64, 4, 8), 1), |
| 778 | + ((16,), 1), |
| 779 | + ((1, 64), 1), |
| 780 | +
|
| 781 | + ((64, 4, 64), 2), |
| 782 | + ((1, 1, 64), 2), |
| 783 | + ((64, 4, 4), 2), |
| 784 | + ((1, 1, 4), 2), |
| 785 | + ((64, 4, 8), 2), |
| 786 | + ((16,), 2), |
| 787 | + ((1, 4), 2), |
| 788 | + ((64, 4), 2)], |
| 789 | + [ |
| 790 | + (sample_stats.windowed_mean, np.mean), |
| 791 | + (sample_stats.windowed_variance, np.var) |
| 792 | + ])]) |
| 793 | + def test_windowed(self, shape, indice_shape, axis, window_func, np_func): |
| 794 | + self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func, |
| 795 | + np_func) |
| 796 | + |
| 797 | + |
| 798 | + @parameterized.named_parameters(*[( |
| 799 | + f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis, |
| 800 | + tf_func, np_func) for a, (b, axis), (tf_func, np_func) in |
| 801 | + itertools.product([(64, 4, 8), ], |
| 802 | + [((4, 1, 4), 2), ((2, 4), 2)], |
| 803 | + [(sample_stats.windowed_mean, np.mean), |
| 804 | + (sample_stats.windowed_variance, np.var)])]) |
| 805 | + def test_non_broadcastable_shapes(self, shape, indice_shape, axis, |
| 806 | + window_func, np_func): |
| 807 | + with self.assertRaisesRegexp((IndexError, ValueError, InvalidArgumentError), |
| 808 | + '^shape mismatch|Incompatible shapes'): |
| 809 | + self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func, |
| 810 | + np_func) |
801 | 811 |
|
802 | 812 |
|
803 | 813 | @test_util.test_all_tf_execution_regimes |
|
0 commit comments