Skip to content

Commit 2040f08

Browse files
shs037 authored and tensorflower-gardener committed
Allows slicing by custom indices.
PiperOrigin-RevId: 486998645
1 parent ec747a8 commit 2040f08

File tree

5 files changed

+258
-16
lines changed

5 files changed

+258
-16
lines changed

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import logging
2121
import os
2222
import pickle
23-
from typing import Any, Iterable, MutableSequence, Optional, Union
23+
from typing import Any, Iterable, MutableSequence, Optional, Union, Sequence
2424

2525
import numpy as np
2626
import pandas as pd
@@ -39,6 +39,7 @@ class SlicingFeature(enum.Enum):
3939
CLASS = 'class'
4040
PERCENTILE = 'percentile'
4141
CORRECTLY_CLASSIFIED = 'correctly_classified'
42+
CUSTOM = 'custom'
4243

4344

4445
@dataclasses.dataclass
@@ -65,6 +66,11 @@ def __str__(self):
6566
if self.feature == SlicingFeature.PERCENTILE:
6667
return 'Loss percentiles: %d-%d' % self.value
6768

69+
if self.feature == SlicingFeature.CUSTOM:
70+
custom_train_indices, custom_test_indices, group_value = self.value
71+
return (f'Custom indices: train = {custom_train_indices}, '
72+
f'test = {custom_test_indices}, group_value = {group_value}')
73+
6874
return '%s=%s' % (self.feature.name, self.value)
6975

7076

@@ -91,6 +97,37 @@ class SlicingSpec:
9197
# examples will be generated.
9298
by_classification_correctness: bool = False
9399

100+
# When both `all_custom_train_indices` and `all_custom_test_indices` are set,
101+
# will slice by custom indices.
102+
# `custom_train_indices` and `custom_test_indices` are sequences containing
103+
# the same number of arrays. Each array indicates the grouping of training and
104+
# test examples, and should have a length equal to the number of training and
105+
# test examples.
106+
# For example, suppose we have 3 training examples (a1, a2, a3), and
107+
# 2 test examples (b1, b2). Then,
108+
# all_custom_train_indices = [np.array([2, 1, 2]), np.array([0, 0, 1])]
109+
# all_custom_test_indices = [np.array([1, 2]), np.array([1, 0])]
110+
# means we are going to consider two ways of slicing them:
111+
# 1. two groups: (a2, b1) corresponding to value 1, (a1, a3, b2) corresponding
112+
# to value 2.
113+
# 2. two groups: (a1, a2, b2) corresponding to value 0, (a3, b1) corresponding
114+
# to value 1.
115+
all_custom_train_indices: Optional[Sequence[np.ndarray]] = None
116+
all_custom_test_indices: Optional[Sequence[np.ndarray]] = None
117+
118+
def __post_init__(self):
  """Validates the custom-index fields once the dataclass is initialized.

  Custom-index slicing is considered requested when the fields are truthy
  (i.e. non-empty sequences). Either both fields must be set or neither.

  Raises:
    ValueError: If exactly one of `all_custom_train_indices` and
      `all_custom_test_indices` is set, or if both are set but contain a
      different number of arrays.
  """
  train_given = bool(self.all_custom_train_indices)
  test_given = bool(self.all_custom_test_indices)

  # Nothing to validate when custom slicing is not requested at all.
  if not train_given and not test_given:
    return

  if train_given != test_given:
    raise ValueError(
        'all_custom_train_indices and all_custom_test_indices must '
        'be provided or set to None at the same time.')

  num_train = len(self.all_custom_train_indices)
  num_test = len(self.all_custom_test_indices)
  if num_train != num_test:
    # Note the trailing space after "but got": without it the two counts are
    # glued onto the preceding word in the rendered message.
    raise ValueError(
        'all_custom_train_indices and all_custom_test_indices '
        'should have the same length, but got '
        f'{num_train} and {num_test}.')
130+
94131
def __str__(self):
95132
"""Only keeps the True values."""
96133
result = ['SlicingSpec(']
@@ -107,6 +144,8 @@ def __str__(self):
107144
result.append(' By percentiles,')
108145
if self.by_classification_correctness:
109146
result.append(' By classification correctness,')
147+
if self.all_custom_train_indices:
148+
result.append(' By custom indices,')
110149
result.append(')')
111150
return '\n'.join(result)
112151

@@ -123,8 +162,9 @@ class AttackType(enum.Enum):
123162
@property
124163
def is_trained_attack(self):
125164
"""Returns whether this type of attack requires training a model."""
126-
return (self != AttackType.THRESHOLD_ATTACK) and (
127-
self != AttackType.THRESHOLD_ENTROPY_ATTACK)
165+
# Compare by name instead of the variable itself to support module reload.
166+
return self.name not in (AttackType.THRESHOLD_ATTACK.name,
167+
AttackType.THRESHOLD_ENTROPY_ATTACK.name)
128168

129169
def __str__(self):
130170
"""Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION."""

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def testStrEntireDataset(self):
4444
(SlicingFeature.CLASS, 2, 'CLASS=2'),
4545
(SlicingFeature.PERCENTILE, (10, 20), 'Loss percentiles: 10-20'),
4646
(SlicingFeature.CORRECTLY_CLASSIFIED, True, 'CORRECTLY_CLASSIFIED=True'),
47+
(SlicingFeature.CUSTOM, (np.array([1]), np.array([2, 1]), 1),
48+
'Custom indices: train = [1], test = [2 1], group_value = 1'),
4749
)
4850
def testStr(self, feature, value, expected_str):
4951
self.assertEqual(str(SingleSliceSpec(feature, value)), expected_str)

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,16 @@
2626
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
2727

2828

29+
_MAX_NUM_OF_SLICES = 1000
30+
31+
2932
def _slice_if_not_none(a, idx):
  """Returns `a[idx]`, or None when `a` itself is None.

  Args:
    a: An indexable object, or None.
    idx: Anything accepted by `a.__getitem__` (index, slice, mask, ...).

  Returns:
    None if `a` is None, otherwise the result of indexing `a` with `idx`.
  """
  return None if a is None else a[idx]
3134

3235

3336
def _slice_data_by_indices(data: AttackInputData, idx_train,
3437
idx_test) -> AttackInputData:
35-
"""Slices train fields with with idx_train and test fields with and idx_test."""
38+
"""Slices train fields with idx_train and test fields with idx_test."""
3639

3740
result = AttackInputData()
3841

@@ -128,10 +131,55 @@ def _slice_by_classification_correctness(data: AttackInputData,
128131
return _slice_data_by_indices(data, idx_train, idx_test)
129132

130133

134+
def _slice_by_custom_indices(data: AttackInputData,
135+
custom_train_indices: np.ndarray,
136+
custom_test_indices: np.ndarray,
137+
group_value: int) -> AttackInputData:
138+
"""Slices attack inputs by custom indices.
139+
140+
Args:
141+
data: Data to be used as input to the attack models.
142+
custom_train_indices: The group indices of each training example.
143+
custom_test_indices: The group indices of each test example.
144+
group_value: The group value to pick.
145+
146+
Returns:
147+
AttackInputData object containing the sliced data.
148+
"""
149+
train_size, test_size = data.get_train_size(), data.get_test_size()
150+
if custom_train_indices.shape[0] != train_size:
151+
raise ValueError(
152+
"custom_train_indices should have the same number of elements as "
153+
f"the training data, but got {custom_train_indices.shape} and "
154+
f"{train_size}")
155+
if custom_test_indices.shape[0] != test_size:
156+
raise ValueError(
157+
"custom_test_indices should have the same number of elements as "
158+
f"the test data, but got {custom_test_indices.shape} and "
159+
f"{test_size}")
160+
idx_train = custom_train_indices == group_value
161+
idx_test = custom_test_indices == group_value
162+
return _slice_data_by_indices(data, idx_train, idx_test)
163+
164+
131165
def get_single_slice_specs(
132166
slicing_spec: SlicingSpec,
133167
num_classes: Optional[int] = None) -> List[SingleSliceSpec]:
134-
"""Returns slices of data according to slicing_spec."""
168+
"""Returns slices of data according to slicing_spec.
169+
170+
Args:
171+
slicing_spec: the slicing specification
172+
num_classes: number of classes of the examples. Required when slicing by
173+
class.
174+
175+
Returns:
176+
Slices of data according to the slicing specification.
177+
178+
Raises:
179+
ValueError: If the number of slices is above `_MAX_NUM_OF_SLICES` when
180+
slicing by class or slicing with custom indices. Or, if `num_classes` is
181+
not provided when slicing by class.
182+
"""
135183
result = []
136184

137185
if slicing_spec.entire_dataset:
@@ -141,10 +189,12 @@ def get_single_slice_specs(
141189
by_class = slicing_spec.by_class
142190
if isinstance(by_class, bool):
143191
if by_class:
144-
assert num_classes, "When by_class == True, num_classes should be given."
145-
assert 0 <= num_classes <= 1000, (
146-
f"Too much classes for slicing by classes. "
147-
f"Found {num_classes}.")
192+
if not num_classes:
193+
raise ValueError("When by_class == True, num_classes should be given.")
194+
if not 0 <= num_classes <= _MAX_NUM_OF_SLICES:
195+
raise ValueError(f"Too many classes for slicing by classes. "
196+
f"Found {num_classes}."
197+
f"Should be no more than {_MAX_NUM_OF_SLICES}.")
148198
for c in range(num_classes):
149199
result.append(SingleSliceSpec(SlicingFeature.CLASS, c))
150200
elif isinstance(by_class, int):
@@ -164,6 +214,23 @@ def get_single_slice_specs(
164214
result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, True))
165215
result.append(SingleSliceSpec(SlicingFeature.CORRECTLY_CLASSIFIED, False))
166216

217+
# Create slices by custom indices.
218+
if slicing_spec.all_custom_train_indices:
219+
for custom_train_indices, custom_test_indices in zip(
220+
slicing_spec.all_custom_train_indices,
221+
slicing_spec.all_custom_test_indices):
222+
groups = np.intersect1d(
223+
np.unique(custom_train_indices),
224+
np.unique(custom_test_indices),
225+
assume_unique=True)
226+
if not 0 <= groups.size <= _MAX_NUM_OF_SLICES:
227+
raise ValueError(
228+
f"Too many groups ({groups.size}) for slicing by custom indices. "
229+
f"Should be no more than {_MAX_NUM_OF_SLICES}.")
230+
for g in groups:
231+
result.append(
232+
SingleSliceSpec(SlicingFeature.CUSTOM,
233+
(custom_train_indices, custom_test_indices, g)))
167234
return result
168235

169236

@@ -179,6 +246,10 @@ def get_slice(data: AttackInputData,
179246
data_slice = _slice_by_percentiles(data, from_percentile, to_percentile)
180247
elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED:
181248
data_slice = _slice_by_classification_correctness(data, slice_spec.value)
249+
elif slice_spec.feature == SlicingFeature.CUSTOM:
250+
custom_train_indices, custom_test_indices, group_value = slice_spec.value
251+
data_slice = _slice_by_custom_indices(data, custom_train_indices,
252+
custom_test_indices, group_value)
182253
else:
183254
raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature)
184255

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py

Lines changed: 131 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
# limitations under the License.
1414

1515
import logging
16+
1617
from absl.testing import absltest
18+
from absl.testing import parameterized
1719
from absl.testing.absltest import mock
1820
import numpy as np
19-
2021
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
2122
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleSliceSpec
2223
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingFeature
@@ -38,7 +39,7 @@ def _are_lists_equal(lhs, rhs) -> bool:
3839
return True
3940

4041

41-
class SingleSliceSpecsTest(absltest.TestCase):
42+
class SingleSliceSpecsTest(parameterized.TestCase):
4243
"""Tests for get_single_slice_specs."""
4344

4445
ENTIRE_DATASET_SLICE = SingleSliceSpec()
@@ -95,8 +96,81 @@ def test_slicing_by_multiple_features(self):
9596
output = get_single_slice_specs(input_data, n_classes)
9697
self.assertLen(output, expected_slices)
9798

99+
@parameterized.parameters(
    # (train group indices, test group indices, expected group values)
    (np.array([1, 2, 1, 2]), np.array([2, 2, 1, 2]), [1, 2]),
    (np.array([0, -1, 2, -1, 2]), np.array([2, 2, -1, 2]), [-1, 2]),
    (np.array([1, 2, 1, 2] + list(range(5000))), np.array([2, 2, 1]), [1, 2]),
    # No value appears in both train and test, so no custom slice is expected.
    (np.array([1, 2, 1, 2]), np.array([3, 4]), []),
)
def test_slicing_by_custom_indices_one_pair(self, custom_train_indices,
                                            custom_test_indices,
                                            expected_groups):
  """Checks the slice specs generated from a single train/test index pair.

  The expected output is the entire-dataset slice followed by one CUSTOM
  slice per value in `expected_groups`.
  """
  input_data = SlicingSpec(
      all_custom_train_indices=[custom_train_indices],
      all_custom_test_indices=[custom_test_indices])
  expected = [self.ENTIRE_DATASET_SLICE] + [
      SingleSliceSpec(SlicingFeature.CUSTOM,
                      (custom_train_indices, custom_test_indices, g))
      for g in expected_groups
  ]
  output = get_single_slice_specs(input_data)
  self.assertTrue(_are_lists_equal(output, expected))
98118

99-
class GetSliceTest(absltest.TestCase):
119+
def test_slicing_by_custom_indices_multi_pairs(self):
  """Checks the slice specs generated from several train/test index pairs.

  The expected output is the entire-dataset slice followed, pair by pair, by
  one CUSTOM slice per value listed for that pair in `expected_group_values`.
  """
  all_custom_train_indices = [
      np.array([1, 2, 1, 2]),
      np.array([0, -1, 2, -1, 2]),
      np.array([1, 2, 1, 2] + list(range(5000))),
      np.array([1, 2, 1, 2])
  ]
  all_custom_test_indices = [
      np.array([2, 2, 1, 2]),
      np.array([2, 2, -1, 2]),
      np.array([2, 2, 1]),
      np.array([3, 4])
  ]
  # One list of expected group values per (train, test) pair above.
  expected_group_values = [[1, 2], [-1, 2], [1, 2], []]

  input_data = SlicingSpec(
      all_custom_train_indices=all_custom_train_indices,
      all_custom_test_indices=all_custom_test_indices)
  expected = [self.ENTIRE_DATASET_SLICE]
  for custom_train_indices, custom_test_indices, eg in zip(
      all_custom_train_indices, all_custom_test_indices,
      expected_group_values):
    expected.extend([
        SingleSliceSpec(SlicingFeature.CUSTOM,
                        (custom_train_indices, custom_test_indices, g))
        for g in eg
    ])
  output = get_single_slice_specs(input_data)
  self.assertTrue(_are_lists_equal(output, expected))
148+
149+
@parameterized.parameters(
    # Only the train indices are given.
    ([np.array([1, 2])], None),
    # Only the test indices are given.
    (None, [np.array([1, 2])]),
    # Train indices are an empty list while test indices are given.
    ([], [np.array([1, 2])]),
    # Both are given but contain a different number of arrays.
    ([np.array([1, 2])], [np.array([1, 2]),
                          np.array([1, 2])]),
)
def test_slicing_by_custom_indices_wrong_indices(self,
                                                 all_custom_train_indices,
                                                 all_custom_test_indices):
  """Checks that constructing SlicingSpec with bad index lists raises."""
  with self.assertRaises(ValueError):
    SlicingSpec(
        all_custom_train_indices=all_custom_train_indices,
        all_custom_test_indices=all_custom_test_indices)
164+
165+
def test_slicing_by_custom_indices_too_many_groups(self):
  """Checks that a pair with 1001 distinct shared values is rejected."""
  input_data = SlicingSpec(
      # The first pair has 1001 distinct values in both train and test; the
      # second pair has only 3.
      all_custom_train_indices=[np.arange(1001),
                                np.arange(3)],
      all_custom_test_indices=[np.arange(1001), np.arange(3)])
  self.assertRaises(ValueError, get_single_slice_specs, input_data)
171+
172+
173+
class GetSliceTest(parameterized.TestCase):
100174

101175
def __init__(self, methodname):
102176
"""Initialize the test class."""
@@ -210,6 +284,40 @@ def test_slice_by_correctness(self):
210284
self.assertTrue((output.labels_train == [0, 2]).all())
211285
self.assertTrue((output.labels_test == [1, 2, 0]).all())
212286

287+
def test_slice_by_custom_indices(self):
  """Checks every field of the slice selected by custom group value 2.

  The expected arrays depend on `self.input_data`, which is built outside
  this method (presumably in the test class initializer; verify there).
  """
  custom_train_indices = np.array([2, 2, 100, 4])
  custom_test_indices = np.array([100, 2, 2, 2])
  custom_slice = SingleSliceSpec(
      SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2))
  output = get_slice(self.input_data, custom_slice)
  np.testing.assert_array_equal(output.logits_train,
                                np.array([[0, 1, 0], [2, 0, 3]]))
  np.testing.assert_array_equal(
      output.logits_test, np.array([[12, 13, 0], [14, 15, 0], [0, 16, 17]]))
  np.testing.assert_array_equal(output.probs_train,
                                np.array([[0, 1, 0], [0.1, 0, 0.7]]))
  np.testing.assert_array_equal(
      output.probs_test,
      np.array([[0.1, 0.9, 0], [0.15, 0.85, 0], [0, 0, 1]]))
  np.testing.assert_array_equal(output.labels_train, np.array([1, 0]))
  np.testing.assert_array_equal(output.labels_test, np.array([2, 0, 2]))
  np.testing.assert_array_equal(output.loss_train, np.array([2, 0.25]))
  np.testing.assert_array_equal(output.loss_test, np.array([3.5, 7, 4.5]))
  np.testing.assert_array_equal(output.entropy_train, np.array([0.4, 8]))
  np.testing.assert_array_equal(output.entropy_test,
                                np.array([10.5, 4.5, 0.3]))
309+
310+
@parameterized.parameters(
    # (train group indices, test group indices); each pair is expected to
    # make get_slice raise for `self.input_data`.
    (np.array([2, 2, 100]), np.array([100, 2, 2])),
    (np.array([2, 2, 100, 4]), np.array([100, 2, 2])),
    (np.array([2, 100, 4]), np.array([100, 2, 2, 2])),
)
def test_slice_by_custom_indices_wrong_size(self, custom_train_indices,
                                            custom_test_indices):
  """Checks that get_slice raises ValueError for these index arrays."""
  custom_slice = SingleSliceSpec(
      SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2))
  self.assertRaises(ValueError, get_slice, self.input_data, custom_slice)
320+
213321

214322
class GetSliceTestForMultilabelData(absltest.TestCase):
215323

@@ -288,6 +396,26 @@ def test_slice_by_correctness_fails(self):
288396
False)
289397
self.assertRaises(ValueError, get_slice, self.input_data, percentile_slice)
290398

399+
def test_slice_by_custom_indices(self):
  """Checks logits and labels of the slice selected by custom group value 2.

  The expected arrays depend on `self.input_data`, which is built outside
  this method (presumably in the test class initializer; verify there).
  """
  custom_train_indices = np.array([2, 2, 100, 4])
  custom_test_indices = np.array([100, 2, 2, 2])
  custom_slice = SingleSliceSpec(
      SlicingFeature.CUSTOM, (custom_train_indices, custom_test_indices, 2))
  output = get_slice(self.input_data, custom_slice)
  # Check logits.
  with self.subTest(msg='Check logits'):
    np.testing.assert_array_equal(output.logits_train,
                                  np.array([[0, 1, 0], [2, 0, 3]]))
    np.testing.assert_array_equal(
        output.logits_test, np.array([[12, 13, 0], [14, 15, 0], [0, 16, 17]]))

  # Check labels.
  with self.subTest(msg='Check labels'):
    np.testing.assert_array_equal(output.labels_train,
                                  np.array([[0, 1, 1], [1, 0, 1]]))
    np.testing.assert_array_equal(output.labels_test,
                                  np.array([[0, 1, 0], [0, 1, 0], [0, 0, 1]]))
418+
291419

292420
if __name__ == '__main__':
293421
absltest.main()

0 commit comments

Comments
 (0)