13
13
# limitations under the License.
14
14
15
15
import logging
16
+
16
17
from absl .testing import absltest
18
+ from absl .testing import parameterized
17
19
from absl .testing .absltest import mock
18
20
import numpy as np
19
-
20
21
from tensorflow_privacy .privacy .privacy_tests .membership_inference_attack .data_structures import AttackInputData
21
22
from tensorflow_privacy .privacy .privacy_tests .membership_inference_attack .data_structures import SingleSliceSpec
22
23
from tensorflow_privacy .privacy .privacy_tests .membership_inference_attack .data_structures import SlicingFeature
@@ -38,7 +39,7 @@ def _are_lists_equal(lhs, rhs) -> bool:
38
39
return True
39
40
40
41
41
- class SingleSliceSpecsTest (absltest .TestCase ):
42
+ class SingleSliceSpecsTest (parameterized .TestCase ):
42
43
"""Tests for get_single_slice_specs."""
43
44
44
45
ENTIRE_DATASET_SLICE = SingleSliceSpec ()
@@ -95,8 +96,81 @@ def test_slicing_by_multiple_features(self):
95
96
output = get_single_slice_specs (input_data , n_classes )
96
97
self .assertLen (output , expected_slices )
97
98
99
+ @parameterized .parameters (
100
+ (np .array ([1 , 2 , 1 , 2 ]), np .array ([2 , 2 , 1 , 2 ]), [1 , 2 ]),
101
+ (np .array ([0 , - 1 , 2 , - 1 , 2 ]), np .array ([2 , 2 , - 1 , 2 ]), [- 1 , 2 ]),
102
+ (np .array ([1 , 2 , 1 , 2 ] + list (range (5000 ))), np .array ([2 , 2 , 1 ]), [1 , 2 ]),
103
+ (np .array ([1 , 2 , 1 , 2 ]), np .array ([3 , 4 ]), []),
104
+ )
105
+ def test_slicing_by_custom_indices_one_pair (self , custom_train_indices ,
106
+ custom_test_indices ,
107
+ expected_groups ):
108
+ input_data = SlicingSpec (
109
+ all_custom_train_indices = [custom_train_indices ],
110
+ all_custom_test_indices = [custom_test_indices ])
111
+ expected = [self .ENTIRE_DATASET_SLICE ] + [
112
+ SingleSliceSpec (SlicingFeature .CUSTOM ,
113
+ (custom_train_indices , custom_test_indices , g ))
114
+ for g in expected_groups
115
+ ]
116
+ output = get_single_slice_specs (input_data )
117
+ self .assertTrue (_are_lists_equal (output , expected ))
98
118
99
- class GetSliceTest (absltest .TestCase ):
119
+ def test_slicing_by_custom_indices_multi_pairs (self ):
120
+ all_custom_train_indices = [
121
+ np .array ([1 , 2 , 1 , 2 ]),
122
+ np .array ([0 , - 1 , 2 , - 1 , 2 ]),
123
+ np .array ([1 , 2 , 1 , 2 ] + list (range (5000 ))),
124
+ np .array ([1 , 2 , 1 , 2 ])
125
+ ]
126
+ all_custom_test_indices = [
127
+ np .array ([2 , 2 , 1 , 2 ]),
128
+ np .array ([2 , 2 , - 1 , 2 ]),
129
+ np .array ([2 , 2 , 1 ]),
130
+ np .array ([3 , 4 ])
131
+ ]
132
+ expected_group_values = [[1 , 2 ], [- 1 , 2 ], [1 , 2 ], []]
133
+
134
+ input_data = SlicingSpec (
135
+ all_custom_train_indices = all_custom_train_indices ,
136
+ all_custom_test_indices = all_custom_test_indices )
137
+ expected = [self .ENTIRE_DATASET_SLICE ]
138
+ for custom_train_indices , custom_test_indices , eg in zip (
139
+ all_custom_train_indices , all_custom_test_indices ,
140
+ expected_group_values ):
141
+ expected .extend ([
142
+ SingleSliceSpec (SlicingFeature .CUSTOM ,
143
+ (custom_train_indices , custom_test_indices , g ))
144
+ for g in eg
145
+ ])
146
+ output = get_single_slice_specs (input_data )
147
+ self .assertTrue (_are_lists_equal (output , expected ))
148
+
149
+ @parameterized .parameters (
150
+ ([np .array ([1 , 2 ])], None ),
151
+ (None , [np .array ([1 , 2 ])]),
152
+ ([], [np .array ([1 , 2 ])]),
153
+ ([np .array ([1 , 2 ])], [np .array ([1 , 2 ]),
154
+ np .array ([1 , 2 ])]),
155
+ )
156
+ def test_slicing_by_custom_indices_wrong_indices (self ,
157
+ all_custom_train_indices ,
158
+ all_custom_test_indices ):
159
+ self .assertRaises (
160
+ ValueError ,
161
+ SlicingSpec ,
162
+ all_custom_train_indices = all_custom_train_indices ,
163
+ all_custom_test_indices = all_custom_test_indices )
164
+
165
+ def test_slicing_by_custom_indices_too_many_groups (self ):
166
+ input_data = SlicingSpec (
167
+ all_custom_train_indices = [np .arange (1001 ),
168
+ np .arange (3 )],
169
+ all_custom_test_indices = [np .arange (1001 ), np .arange (3 )])
170
+ self .assertRaises (ValueError , get_single_slice_specs , input_data )
171
+
172
+
173
+ class GetSliceTest (parameterized .TestCase ):
100
174
101
175
def __init__ (self , methodname ):
102
176
"""Initialize the test class."""
@@ -210,6 +284,40 @@ def test_slice_by_correctness(self):
210
284
self .assertTrue ((output .labels_train == [0 , 2 ]).all ())
211
285
self .assertTrue ((output .labels_test == [1 , 2 , 0 ]).all ())
212
286
287
+ def test_slice_by_custom_indices (self ):
288
+ custom_train_indices = np .array ([2 , 2 , 100 , 4 ])
289
+ custom_test_indices = np .array ([100 , 2 , 2 , 2 ])
290
+ custom_slice = SingleSliceSpec (
291
+ SlicingFeature .CUSTOM , (custom_train_indices , custom_test_indices , 2 ))
292
+ output = get_slice (self .input_data , custom_slice )
293
+ np .testing .assert_array_equal (output .logits_train ,
294
+ np .array ([[0 , 1 , 0 ], [2 , 0 , 3 ]]))
295
+ np .testing .assert_array_equal (
296
+ output .logits_test , np .array ([[12 , 13 , 0 ], [14 , 15 , 0 ], [0 , 16 , 17 ]]))
297
+ np .testing .assert_array_equal (output .probs_train ,
298
+ np .array ([[0 , 1 , 0 ], [0.1 , 0 , 0.7 ]]))
299
+ np .testing .assert_array_equal (
300
+ output .probs_test , np .array ([[0.1 , 0.9 , 0 ], [0.15 , 0.85 , 0 ], [0 , 0 ,
301
+ 1 ]]))
302
+ np .testing .assert_array_equal (output .labels_train , np .array ([1 , 0 ]))
303
+ np .testing .assert_array_equal (output .labels_test , np .array ([2 , 0 , 2 ]))
304
+ np .testing .assert_array_equal (output .loss_train , np .array ([2 , 0.25 ]))
305
+ np .testing .assert_array_equal (output .loss_test , np .array ([3.5 , 7 , 4.5 ]))
306
+ np .testing .assert_array_equal (output .entropy_train , np .array ([0.4 , 8 ]))
307
+ np .testing .assert_array_equal (output .entropy_test ,
308
+ np .array ([10.5 , 4.5 , 0.3 ]))
309
+
310
+ @parameterized .parameters (
311
+ (np .array ([2 , 2 , 100 ]), np .array ([100 , 2 , 2 ])),
312
+ (np .array ([2 , 2 , 100 , 4 ]), np .array ([100 , 2 , 2 ])),
313
+ (np .array ([2 , 100 , 4 ]), np .array ([100 , 2 , 2 , 2 ])),
314
+ )
315
+ def test_slice_by_custom_indices_wrong_size (self , custom_train_indices ,
316
+ custom_test_indices ):
317
+ custom_slice = SingleSliceSpec (
318
+ SlicingFeature .CUSTOM , (custom_train_indices , custom_test_indices , 2 ))
319
+ self .assertRaises (ValueError , get_slice , self .input_data , custom_slice )
320
+
213
321
214
322
class GetSliceTestForMultilabelData (absltest .TestCase ):
215
323
@@ -288,6 +396,26 @@ def test_slice_by_correctness_fails(self):
288
396
False )
289
397
self .assertRaises (ValueError , get_slice , self .input_data , percentile_slice )
290
398
399
+ def test_slice_by_custom_indices (self ):
400
+ custom_train_indices = np .array ([2 , 2 , 100 , 4 ])
401
+ custom_test_indices = np .array ([100 , 2 , 2 , 2 ])
402
+ custom_slice = SingleSliceSpec (
403
+ SlicingFeature .CUSTOM , (custom_train_indices , custom_test_indices , 2 ))
404
+ output = get_slice (self .input_data , custom_slice )
405
+ # Check logits.
406
+ with self .subTest (msg = 'Check logits' ):
407
+ np .testing .assert_array_equal (output .logits_train ,
408
+ np .array ([[0 , 1 , 0 ], [2 , 0 , 3 ]]))
409
+ np .testing .assert_array_equal (
410
+ output .logits_test , np .array ([[12 , 13 , 0 ], [14 , 15 , 0 ], [0 , 16 , 17 ]]))
411
+
412
+ # Check labels.
413
+ with self .subTest (msg = 'Check labels' ):
414
+ np .testing .assert_array_equal (output .labels_train ,
415
+ np .array ([[0 , 1 , 1 ], [1 , 0 , 1 ]]))
416
+ np .testing .assert_array_equal (output .labels_test ,
417
+ np .array ([[0 , 1 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ]]))
418
+
291
419
292
420
if __name__ == '__main__' :
293
421
absltest .main ()
0 commit comments