From fe606607837cffb61c54bdf94605bf27a8288d4a Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Mon, 7 Apr 2025 13:02:44 -0700 Subject: [PATCH 1/3] Add `cardinality_rule='match'` option to `RegexGenerator` (#971) --- rdt/transformers/id.py | 188 ++++++++++---- tests/integration/transformers/test_id.py | 91 +++++++ tests/unit/transformers/test_id.py | 288 ++++++++++++++++++++++ 3 files changed, 513 insertions(+), 54 deletions(-) diff --git a/rdt/transformers/id.py b/rdt/transformers/id.py index 99adb3f4..f21010b6 100644 --- a/rdt/transformers/id.py +++ b/rdt/transformers/id.py @@ -97,16 +97,19 @@ class RegexGenerator(BaseTransformer): ``regex`` format. Args: - regex (str): + regex_format (str): String representing the regex function. enforce_uniqueness (bool): **DEPRECATED** Whether or not to ensure that the new generated data is all unique. If it isn't possible to create the requested number of rows, then an error will be raised. Defaults to ``None``. cardinality_rule (str): - Rule that the generated data must follow. If set to ``unique``, the generated - data must be unique. If set to ``None``, then the generated data may contain - duplicates. Defaults to ``None``. + Rule that the generated data must follow. + - If set to 'unique', the generated data must be unique. + - If set to 'match', the generated data must have the exact same cardinality + (# of unique values) as the real data. + - If set to ``None``, then the generated data may contain duplicates. + Defaults to ``None``. generation_order (str): String defining how to generate the output. If set to ``alphanumeric``, it will generate the output in alphanumeric order (ie. 'aaa', 'aab' or '1', '2'...). If @@ -154,6 +157,8 @@ def __init__( self.cardinality_rule = _handle_enforce_uniqueness_and_cardinality_rule( enforce_uniqueness, cardinality_rule ) + self._data_cardinality = None + self._unique_regex_values = None self.data_length = None self.generator = None self.generator_size = None @@ -169,6 +174,40 @@ def reset_randomization(self): self.generator, self.generator_size = strings_from_regex(self.regex_format) self.generated = 0 + def _generate_fallback_samples(self, num_samples, template_samples): + """Generate values such that they are all unique, disregarding the regex.""" + try: + # Integer-based fallback: attempt to convert the last template sample to an integer + # and then generate values in a sequential manner. + start = int(template_samples[-1]) + 1 + return [str(i) for i in range(start, start + num_samples)] + + except ValueError: + # String-based fallback: if the integer conversion fails, it uses the template + # samples as a base and appends a counter to make each value unique. + counter = 0 + samples = [] + while num_samples > 0: + samples.extend([f'{i}({counter})' for i in template_samples[:num_samples]]) + num_samples -= len(template_samples) + counter += 1 + + return samples + + def _generate_unique_regex_values(self): + regex_values = [] + try: + while len(regex_values) < self._data_cardinality: + regex_values.append(next(self.generator)) + except (RuntimeError, StopIteration): + fallback_samples = self._generate_fallback_samples( + num_samples=self._data_cardinality - len(regex_values), + template_samples=regex_values, + ) + regex_values.extend(fallback_samples) + + return regex_values + def _fit(self, data): """Fit the transformer to the data. @@ -179,11 +218,15 @@ def _fit(self, data): self.reset_randomization() self.data_length = len(data) + if self.cardinality_rule == 'match': + is_nan = int(pd.isna(data).any()) # nans count as a unique value + self._data_cardinality = data.nunique() + is_nan + def _transform(self, _data): """Drop the input column by returning ``None``.""" return None - def _warn_not_enough_unique_values(self, sample_size, unique_condition): + def _warn_not_enough_unique_values(self, sample_size, unique_condition, match_cardinality): """Warn the user that the regex cannot generate enough unique values. Args: @@ -191,15 +234,18 @@ def _warn_not_enough_unique_values(self, sample_size, unique_condition): Number of samples to be generated. unique_condition (bool): Whether or not to enforce uniqueness. + match_cardinality (bool): + Whether or not to match the cardinality of the data. """ warned = False + warn_msg = ( + f"The regex for '{self.get_input_column()}' can only generate " + f'{int(self.generator_size)} unique values. Additional values may not exactly ' + 'follow the provided regex.' + ) if sample_size > self.generator_size: if unique_condition: - warnings.warn( - f"The regex for '{self.get_input_column()}' can only generate " - f'{self.generator_size} unique values. Additional values may not exactly ' - 'follow the provided regex.' - ) + warnings.warn(warn_msg) warned = True else: LOGGER.info( @@ -218,6 +264,70 @@ def _warn_not_enough_unique_values(self, sample_size, unique_condition): f'values (only {max(remaining, 0)} unique values left).' ) + if match_cardinality: + if self._data_cardinality > sample_size: + warnings.warn( + f'Only {sample_size} values can be generated. Cannot match the cardinality ' + f'of the data, it requires {self._data_cardinality} values.' + ) + if sample_size > self.generator_size and self._data_cardinality > self.generator_size: + warnings.warn(warn_msg) + + def _generate_as_many_as_possible(self, num_samples): + """Generate samples. + + Generate values following the regex until either the ``num_samples`` is reached or + the generator is exhausted. + """ + generated_values = [] + try: + while len(generated_values) < num_samples: + generated_values.append(next(self.generator)) + self.generated += 1 + except (RuntimeError, StopIteration): + pass + + return generated_values + + def _generate_num_samples(self, num_samples, template_samples): + """Generate num_samples values from template_samples. + + Eg: num_samples = 5, template_samples = ['a', 'b'] + The output will be ['a', 'b', 'a', 'b', 'a'] + """ + if num_samples <= 0: + return [] + + repeats = num_samples // len(template_samples) + 1 + return np.tile(template_samples, repeats)[:num_samples].tolist() + + def _generate_match_cardinality(self, num_samples): + """Generate values until the sample size is reached, while respecting the cardinality.""" + template_samples = self._unique_regex_values[:num_samples] + samples = self._generate_num_samples(num_samples - len(template_samples), template_samples) + + return template_samples + samples + + def _generate_samples(self, num_samples, match_cardinality_values, unique_condition): + """Generate samples until the sample size is reached.""" + if match_cardinality_values is not None: + return self._generate_match_cardinality(num_samples) + + # If there aren't enough values left in the generator, reset it + if num_samples > self.generator_size - self.generated: + self.reset_randomization() + + samples = self._generate_as_many_as_possible(num_samples) + num_samples -= len(samples) + if num_samples > 0: + if unique_condition: + new_samples = self._generate_fallback_samples(num_samples, samples) + else: + new_samples = self._generate_num_samples(num_samples, samples) + samples.extend(new_samples) + + return samples + def _reverse_transform(self, data): """Generate new data using the provided ``regex_format``. @@ -230,56 +340,26 @@ def _reverse_transform(self, data): """ if hasattr(self, 'cardinality_rule'): unique_condition = self.cardinality_rule == 'unique' + if self.cardinality_rule == 'match' and self._unique_regex_values is None: + self._unique_regex_values = self._generate_unique_regex_values() else: unique_condition = self.enforce_uniqueness - if data is not None and len(data): - sample_size = len(data) + if hasattr(self, '_unique_regex_values'): + match_cardinality_values = self._unique_regex_values else: - sample_size = self.data_length - - self._warn_not_enough_unique_values(sample_size, unique_condition) - - remaining = self.generator_size - self.generated - if sample_size > remaining: - self.reset_randomization() - remaining = self.generator_size - - generated_values = [] - while len(generated_values) < sample_size: - try: - generated_values.append(next(self.generator)) - self.generated += 1 - except (RuntimeError, StopIteration): - # Can't generate more rows without collision so breaking out of loop - break + match_cardinality_values = None - reverse_transformed = generated_values[:] - - if len(reverse_transformed) < sample_size: - if unique_condition: - try: - remaining_samples = sample_size - len(reverse_transformed) - start = int(generated_values[-1]) + 1 - reverse_transformed.extend([ - str(i) for i in range(start, start + remaining_samples) - ]) - - except ValueError: - counter = 0 - while len(reverse_transformed) < sample_size: - remaining_samples = sample_size - len(reverse_transformed) - reverse_transformed.extend([ - f'{i}({counter})' for i in generated_values[:remaining_samples] - ]) - counter += 1 + if data is not None and len(data): + num_samples = len(data) + else: + num_samples = self.data_length - else: - while len(reverse_transformed) < sample_size: - remaining_samples = sample_size - len(reverse_transformed) - reverse_transformed.extend(generated_values[:remaining_samples]) + match_condition = match_cardinality_values is not None + self._warn_not_enough_unique_values(num_samples, unique_condition, match_condition) + samples = self._generate_samples(num_samples, match_cardinality_values, unique_condition) if getattr(self, 'generation_order', 'alphanumeric') == 'scrambled': - np.random.shuffle(reverse_transformed) + np.random.shuffle(samples) - return np.array(reverse_transformed, dtype=object) + return np.array(samples, dtype=object) diff --git a/tests/integration/transformers/test_id.py b/tests/integration/transformers/test_id.py index 8460ddb9..66d4c45f 100644 --- a/tests/integration/transformers/test_id.py +++ b/tests/integration/transformers/test_id.py @@ -335,6 +335,97 @@ def test_cardinality_rule_not_enough_values_numerical(self): expected = pd.DataFrame({'id': ['2', '3', '4', '5', '6']}, dtype=object) pd.testing.assert_frame_equal(reverse_transform, expected) + def test_cardinality_rule_match(self): + """Test with cardinality_rule='match'.""" + # Setup + data = pd.DataFrame({'id': [1, 2, 3, 4, 5]}) + instance = RegexGenerator('[1-3]{1}', cardinality_rule='match') + + # Run + transformed = instance.fit_transform(data, 'id') + reverse_transform = instance.reverse_transform(transformed) + + # Assert + expected = pd.DataFrame({'id': ['1', '2', '3', '4', '5']}, dtype=object) + pd.testing.assert_frame_equal(reverse_transform, expected) + + def test_cardinality_rule_match_not_enough_values(self): + """Test with cardinality_rule='match' but insufficient regex values.""" + # Setup + data = pd.DataFrame({'id': [1, 2, 3, 4, 5]}) + instance = RegexGenerator('[1-3]{1}', cardinality_rule='match') + + # Run + transformed = instance.fit_transform(data, 'id') + reverse_transform = instance.reverse_transform(transformed) + + # Assert + expected = pd.DataFrame({'id': ['1', '2', '3', '4', '5']}, dtype=object) + pd.testing.assert_frame_equal(reverse_transform, expected) + + def test_called_multiple_times_cardinality_rule_match(self): + """Test calling multiple times when ``cardinality_rule`` is ``match``.""" + # Setup + data = pd.DataFrame({'my_column': [1, 2, 3, 4, 5] * 3}) + generator = RegexGenerator(cardinality_rule='match') + + # Run + transformed_data = generator.fit_transform(data, 'my_column') + first_reverse_transform = generator.reverse_transform(transformed_data.head(3)) + second_reverse_transform = generator.reverse_transform(transformed_data.head(4)) + third_reverse_transform = generator.reverse_transform(transformed_data.head(5)) + fourth_reverse_transform = generator.reverse_transform(transformed_data.head(11)) + + # Assert + expected_first_reverse_transform = pd.DataFrame({'my_column': ['AAAAA', 'AAAAB', 'AAAAC']}) + expected_second_reverse_transform = pd.DataFrame({ + 'my_column': ['AAAAA', 'AAAAB', 'AAAAC', 'AAAAD'] + }) + expected_third_reverse_transform = pd.DataFrame({ + 'my_column': ['AAAAA', 'AAAAB', 'AAAAC', 'AAAAD', 'AAAAE'] + }) + expected_fourth_reverse_transform = pd.DataFrame({ + 'my_column': [ + 'AAAAA', + 'AAAAB', + 'AAAAC', + 'AAAAD', + 'AAAAE', + 'AAAAA', + 'AAAAB', + 'AAAAC', + 'AAAAD', + 'AAAAE', + 'AAAAA', + ] + }) + pd.testing.assert_frame_equal(first_reverse_transform, expected_first_reverse_transform) + pd.testing.assert_frame_equal(second_reverse_transform, expected_second_reverse_transform) + pd.testing.assert_frame_equal(third_reverse_transform, expected_third_reverse_transform) + pd.testing.assert_frame_equal(fourth_reverse_transform, expected_fourth_reverse_transform) + + def test_cardinality_rule_match_empty_regex(self): + """Test with cardinality_rule='match' but insufficient regex values.""" + # Setup + data = pd.DataFrame({'id': [1, 2, 3, 4, 5]}) + instance_unique = RegexGenerator('', cardinality_rule='unique') + instance_match = RegexGenerator('', cardinality_rule='match') + instance_none = RegexGenerator('') + + # Run + transformed_unique = instance_unique.fit_transform(data, 'id') + transformed_match = instance_match.fit_transform(data, 'id') + transformed_none = instance_none.fit_transform(data, 'id') + reverse_transform_unique = instance_unique.reverse_transform(transformed_unique) + reverse_transform_match = instance_match.reverse_transform(transformed_match) + reverse_transform_none = instance_none.reverse_transform(transformed_none) + + # Assert + expected = pd.DataFrame({'id': ['', '', '', '', '']}) + pd.testing.assert_frame_equal(reverse_transform_unique, expected) + pd.testing.assert_frame_equal(reverse_transform_match, expected) + pd.testing.assert_frame_equal(reverse_transform_none, expected) + class TestHyperTransformer: def test_end_to_end_scrambled(self): diff --git a/tests/unit/transformers/test_id.py b/tests/unit/transformers/test_id.py index f8b7f573..a60afa20 100644 --- a/tests/unit/transformers/test_id.py +++ b/tests/unit/transformers/test_id.py @@ -178,6 +178,8 @@ def test___getstate__(self): 'regex_format': '[A-Za-z]{5}', 'random_states': mock_random_sates, 'generation_order': 'alphanumeric', + '_unique_regex_values': None, + '_data_cardinality': None, } @patch('rdt.transformers.id.strings_from_regex') @@ -266,6 +268,21 @@ def test___init__custom(self): assert instance.cardinality_rule == 'unique' assert instance.generation_order == 'scrambled' + def test___init__cardinality_rule_match(self): + """Test it when cardinality_rule is 'match'.""" + # Run + instance = RegexGenerator( + regex_format='[0-9]', + cardinality_rule='match', + ) + + # Assert + assert instance.data_length is None + assert instance.regex_format == '[0-9]' + assert instance.cardinality_rule == 'match' + assert instance._data_cardinality is None + assert instance._unique_regex_values is None + def test___init__bad_value_generation_order(self): """Test that an error is raised if a bad value is given for `generation_order`.""" # Run and Assert @@ -347,6 +364,141 @@ def test__fit(self): assert instance.data_length == 3 assert instance.output_properties == {None: {'next_transformer': None}} + def test__fit_cardinality_rule_match(self): + """Test it when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match') + columns_data = pd.Series(['1', '2', '3', '2', '1']) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 5 + assert instance._data_cardinality == 3 + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__reverse_transform_cardinality_rule_match(self): + """Test it when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match') + columns_data = pd.DataFrame({'col': ['1', '2', '3', '2', '1']}) + + # Run + instance.fit(columns_data, 'col') + instance._reverse_transform(columns_data) + + # Assert + assert instance._unique_regex_values == ['AAAAA', 'AAAAB', 'AAAAC'] + + def test__fit_cardinality_rule_match_with_regex_format(self): + """Test it when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-5]{1}') + columns_data = pd.Series(['1', '2', '3', '2', '1']) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 5 + assert instance._data_cardinality == 3 + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__reverse_transform_cardinality_rule_match_with_regex_format(self): + """Test it when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-5]{1}') + columns_data = pd.DataFrame({'col': ['1', '2', '3', '2', '1']}) + + # Run + instance.fit(columns_data, 'col') + instance._reverse_transform(columns_data) + + # Assert + assert instance._unique_regex_values == ['1', '2', '3'] + + def test__fit_cardinality_rule_match_with_nans(self): + """Test it when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-5]{1}') + columns_data = pd.Series(['1', '2', '3', '2', '1', np.nan, None, np.nan]) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 8 + assert instance._data_cardinality == 4 + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__reverse_transform_cardinality_rule_match_with_nans(self): + """Test it when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-5]{1}') + columns_data = pd.DataFrame({'col': ['1', '2', '3', '2', '1', np.nan, None, np.nan]}) + + # Run + instance.fit(columns_data, 'col') + instance._reverse_transform(columns_data) + + # Assert + assert instance._unique_regex_values == ['1', '2', '3', '4'] + + def test__fit_cardinality_rule_match_with_nans_too_many_values(self): + """Test it when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-3]{1}') + columns_data = pd.Series(['1', '2', '3', '2', '4']) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 5 + assert instance._data_cardinality == 4 + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__reverse_transform_cardinality_rule_match_with_nans_too_many_values(self): + """Test it when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-3]{1}') + columns_data = pd.DataFrame({'col': ['1', '2', '3', '2', '4']}) + + # Run + instance.fit(columns_data, 'col') + instance._reverse_transform(columns_data) + + # Assert + assert instance._unique_regex_values == ['1', '2', '3', '4'] + + def test__fit_cardinality_rule_match_with_too_many_values_str(self): + """Test it when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[a-b]{1}') + columns_data = pd.Series(['a', 'b', 'c', 'b', 'd', 'f']) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 6 + assert instance._data_cardinality == 5 + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__reverse_transform_cardinality_rule_match_with_too_many_values_str(self): + """Test it when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[a-b]{1}') + columns_data = pd.DataFrame({'col': ['a', 'b', 'c', 'b', 'd', 'f']}) + + # Run + instance.fit(columns_data, 'col') + instance._reverse_transform(columns_data) + + # Assert + assert instance._unique_regex_values == ['a', 'b', 'a(0)', 'b(0)', 'a(1)'] + def test__transform(self): """Test the ``_transform`` method. @@ -386,6 +538,7 @@ def test__reverse_transform_generation_order_scrambled(self, shuffle_mock): instance.generator_size = 5 instance.generated = 0 instance.generation_order = 'scrambled' + instance.columns = ['col'] # Run result = instance._reverse_transform(columns_data) @@ -417,6 +570,7 @@ def test__reverse_transform_generator_size_bigger_than_data_length(self): instance.generator = generator instance.generator_size = 5 instance.generated = 0 + instance.columns = ['col'] # Run result = instance._reverse_transform(columns_data) @@ -625,3 +779,137 @@ def test__reverse_transform_info_message(self, mock_logger): expected_args = (6, 'a', 5, 'a') mock_logger.info.assert_called_once_with(expected_format, *expected_args) + + def test__reverse_transform_match_not_enough_values(self): + """Test the case when there are not enough values to match the cardinality rule.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'C', 'D', 'E']}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run and Assert + warn_msg = re.escape( + 'Only 3 values can be generated. Cannot match the cardinality ' + 'of the data, it requires 5 values.' + ) + with pytest.warns(UserWarning, match=warn_msg): + out = instance._reverse_transform(data[:3]) + + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C'])) + + def test__reverse_transform_match_too_many_samples(self): + """Test it when the number of samples is bigger than the generator size.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'C', 'D', 'E']}) + instance = RegexGenerator(regex_format='[A-C]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run and Assert + warn_msg = re.escape( + "The regex for 'col' can only generate 3 unique values. Additional values may not " + 'exactly follow the provided regex.' + ) + with pytest.warns(UserWarning, match=warn_msg): + out = instance._reverse_transform(data) + + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C', 'A(0)', 'B(0)'])) + + def test__reverse_transform_only_one_warning(self): + """Test it when the num_samples < generator_size but data_cardinality > generator_size.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'C', 'D', 'E']}) + instance = RegexGenerator(regex_format='[A-C]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run and Assert + warn_msg = re.escape( + 'Only 2 values can be generated. Cannot match the cardinality of the data, ' + 'it requires 5 values.' + ) + with pytest.warns(UserWarning, match=warn_msg): + out = instance._reverse_transform(data[:2]) + + np.testing.assert_array_equal(out, np.array(['A', 'B'])) + instance._unique_regex_values = ['A', 'B', 'C', 'A(0)', 'B(0)'] + + def test__reverse_transform_two_warnings(self): + """Test it when data_cardinality > num_samples > generator_size.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'C', 'D', 'E']}) + instance = RegexGenerator(regex_format='[A-C]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run and Assert + warn_msg_1 = re.escape( + 'Only 4 values can be generated. Cannot match the cardinality of the data, ' + 'it requires 5 values.' + ) + warn_msg_2 = re.escape( + "The regex for 'col' can only generate 3 unique values. Additional values may " + 'not exactly follow the provided regex.' + ) + with pytest.warns(UserWarning, match=warn_msg_1): + with pytest.warns(UserWarning, match=warn_msg_2): + out = instance._reverse_transform(data[:4]) + + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C', 'A(0)'])) + + def test__reverse_transform_match_empty_data(self): + """Test it when the data is empty and the cardinality rule is 'match'.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'C', 'D', 'E']}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(pd.Series()) + + # Assert + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C', 'D', 'E'])) + + def test__reverse_transform_match_with_nans(self): + """Test it when the data has nans and the cardinality rule is 'match'.""" + # Setup + data = pd.DataFrame({'col': ['A', np.nan, np.nan, 'D', np.nan]}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(data) + + # Assert + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C', 'A', 'B'])) + + def test__reverse_transform_match_too_many_values(self): + """Test it when the data has more values than the cardinality rule.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'B', 'C']}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(pd.Series([1] * 10)) + + # Assert + np.testing.assert_array_equal( + out, np.array(['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A']) + ) + + def test__reverse_transform_no_unique_regex_values_attribute(self): + """Test it without the _unique_regex_values attribute.""" + # Setup + instance = RegexGenerator('[A-E]') + delattr(instance, '_unique_regex_values') + instance.data_length = 6 + generator = AsciiGenerator(5) + instance.generator = generator + instance.generator_size = 5 + instance.generated = 0 + instance.columns = ['a'] + columns_data = pd.Series() + + # Run + out = instance._reverse_transform(columns_data) + + # Assert + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C', 'D', 'E', 'A'])) From 182d43b2ece04f29227c2ef24eec482d15e8d473 Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Tue, 8 Apr 2025 17:09:47 -0700 Subject: [PATCH 2/3] Add cardinality_rule='scale' option to RegexGenerator (#973) --- rdt/transformers/id.py | 223 ++++++++++----- rdt/transformers/utils.py | 5 +- tests/integration/transformers/test_id.py | 324 ++++++++++++++++++++++ tests/unit/transformers/test_id.py | 185 +++++++++++- 4 files changed, 655 insertions(+), 82 deletions(-) diff --git a/rdt/transformers/id.py b/rdt/transformers/id.py index f21010b6..06ec9062 100644 --- a/rdt/transformers/id.py +++ b/rdt/transformers/id.py @@ -9,6 +9,7 @@ from rdt.transformers.base import BaseTransformer from rdt.transformers.utils import ( _handle_enforce_uniqueness_and_cardinality_rule, + fill_nan_with_none, strings_from_regex, ) @@ -106,8 +107,10 @@ class RegexGenerator(BaseTransformer): cardinality_rule (str): Rule that the generated data must follow. - If set to 'unique', the generated data must be unique. - - If set to 'match', the generated data must have the exact same cardinality - (# of unique values) as the real data. + - If set to 'match', the generated data will have the exact same cardinality + (number of unique values) as the real data. + - If set to 'scale', the generated data will match the number of repetitions that + each value is allowed to have. - If set to ``None``, then the generated data may contain duplicates. Defaults to ``None``. generation_order (str): @@ -157,25 +160,33 @@ def __init__( self.cardinality_rule = _handle_enforce_uniqueness_and_cardinality_rule( enforce_uniqueness, cardinality_rule ) - self._data_cardinality = None - self._unique_regex_values = None self.data_length = None self.generator = None - self.generator_size = None - self.generated = None if generation_order not in ['alphanumeric', 'scrambled']: raise ValueError("generation_order must be one of 'alphanumeric' or 'scrambled'.") self.generation_order = generation_order + # Used when cardinality_rule is 'scale' + self._data_cardinality_scale = None + self._remaining_samples = {'value': None, 'repetitions': 0} + + # Used when cardinality_rule is 'match' + self._data_cardinality = None + self._unique_regex_values = None + + # Used otherwise + self.generator_size = None + self.generated = None + def reset_randomization(self): """Create a new generator and reset the generated values counter.""" super().reset_randomization() self.generator, self.generator_size = strings_from_regex(self.regex_format) self.generated = 0 - def _generate_fallback_samples(self, num_samples, template_samples): - """Generate values such that they are all unique, disregarding the regex.""" + def _sample_fallback(self, num_samples, template_samples): + """Sample num_samples values such that they are all unique, disregarding the regex.""" try: # Integer-based fallback: attempt to convert the last template sample to an integer # and then generate values in a sequential manner. @@ -187,26 +198,21 @@ def _generate_fallback_samples(self, num_samples, template_samples): # samples as a base and appends a counter to make each value unique. counter = 0 samples = [] - while num_samples > 0: + while num_samples > len(samples): samples.extend([f'{i}({counter})' for i in template_samples[:num_samples]]) - num_samples -= len(template_samples) counter += 1 - return samples + return samples[:num_samples] - def _generate_unique_regex_values(self): - regex_values = [] - try: - while len(regex_values) < self._data_cardinality: - regex_values.append(next(self.generator)) - except (RuntimeError, StopIteration): - fallback_samples = self._generate_fallback_samples( - num_samples=self._data_cardinality - len(regex_values), - template_samples=regex_values, - ) - regex_values.extend(fallback_samples) + def _get_cardinality_frequency(self, data): + """Get number of repetitions and their frequencies.""" + value_counts = data.value_counts(dropna=False) + repetition_counts = value_counts.value_counts().sort_index() + total = repetition_counts.sum() + frequencies = (repetition_counts / total).tolist() + repetitions = repetition_counts.index.tolist() - return regex_values + return repetitions, frequencies def _fit(self, data): """Fit the transformer to the data. @@ -218,9 +224,17 @@ def _fit(self, data): self.reset_randomization() self.data_length = len(data) - if self.cardinality_rule == 'match': - is_nan = int(pd.isna(data).any()) # nans count as a unique value - self._data_cardinality = data.nunique() + is_nan + if hasattr(self, 'cardinality_rule'): + data = fill_nan_with_none(data) + if self.cardinality_rule == 'match': + self._data_cardinality = data.nunique(dropna=False) + + elif self.cardinality_rule == 'scale': + sorted_counts, sorted_frequencies = self._get_cardinality_frequency(data) + self._data_cardinality_scale = { + 'num_repetitions': sorted_counts, + 'frequency': sorted_frequencies, + } def _transform(self, _data): """Drop the input column by returning ``None``.""" @@ -273,61 +287,146 @@ def _warn_not_enough_unique_values(self, sample_size, unique_condition, match_ca if sample_size > self.generator_size and self._data_cardinality > self.generator_size: warnings.warn(warn_msg) - def _generate_as_many_as_possible(self, num_samples): + def _sample_from_generator(self, num_samples): """Generate samples. - Generate values following the regex until either the ``num_samples`` is reached or + Generate values following the regex until either the sample size is reached or the generator is exhausted. """ - generated_values = [] + samples = [] try: - while len(generated_values) < num_samples: - generated_values.append(next(self.generator)) + while len(samples) < num_samples: + samples.append(next(self.generator)) self.generated += 1 except (RuntimeError, StopIteration): pass - return generated_values + return samples - def _generate_num_samples(self, num_samples, template_samples): - """Generate num_samples values from template_samples. + def _sample_from_template(self, num_samples, template_samples): + """Sample num_samples values from template_samples in a cycle. Eg: num_samples = 5, template_samples = ['a', 'b'] The output will be ['a', 'b', 'a', 'b', 'a'] """ - if num_samples <= 0: - return [] - repeats = num_samples // len(template_samples) + 1 return np.tile(template_samples, repeats)[:num_samples].tolist() - def _generate_match_cardinality(self, num_samples): - """Generate values until the sample size is reached, while respecting the cardinality.""" - template_samples = self._unique_regex_values[:num_samples] - samples = self._generate_num_samples(num_samples - len(template_samples), template_samples) + def _sample_match(self, num_samples): + """Sample num_samples values following the 'match' cardinality rule.""" + samples = self._unique_regex_values[:num_samples] + if num_samples > len(samples): + new_samples = self._sample_from_template(num_samples - len(samples), samples) + samples.extend(new_samples) + + return samples - return template_samples + samples + def _sample_repetitions(self, num_samples, value): + """Sample a number of repetitions for a given value.""" + repetitions = np.random.choice( + self._data_cardinality_scale['num_repetitions'], + p=self._data_cardinality_scale['frequency'], + ) + if repetitions <= num_samples: + samples = [value] * repetitions + else: + samples = [value] * num_samples + self._remaining_samples['repetitions'] = repetitions - num_samples + self._remaining_samples['value'] = value - def _generate_samples(self, num_samples, match_cardinality_values, unique_condition): - """Generate samples until the sample size is reached.""" - if match_cardinality_values is not None: - return self._generate_match_cardinality(num_samples) + return samples + + def _sample_scale_fallback(self, num_samples, template_samples): + """Sample num_samples values, disregarding the regex, for the cardinality rule 'scale'.""" + warnings.warn( + f"The regex for '{self.get_input_column()}' cannot generate enough samples. " + 'Additional values may not exactly follow the provided regex.' + ) + samples = [] + fallback_samples = self._sample_fallback(num_samples, template_samples) + while num_samples > len(samples): + fallback_sample = fallback_samples.pop(0) + new_samples = self._sample_repetitions(num_samples - len(samples), fallback_sample) + samples.extend(new_samples) + + return samples + + def _sample_repetitions_from_generator(self, num_samples): + """Sample num_samples values from the generator, or until the generator is exhausted.""" + samples = [self._remaining_samples['value']] * self._remaining_samples['repetitions'] + self._remaining_samples['repetitions'] = 0 + + template_samples = [] + while num_samples > len(samples): + try: + value = next(self.generator) + template_samples.append(value) + except (RuntimeError, StopIteration): + # If the generator is exhausted and no samples have been generated yet, reset it + if len(template_samples) == 0: + self.reset_randomization() + continue + else: + break + + new_samples = self._sample_repetitions(num_samples - len(samples), value) + samples.extend(new_samples) + + return samples, template_samples + + def _sample_scale(self, num_samples): + """Sample num_samples values following the 'scale' cardinality rule.""" + if self._remaining_samples['repetitions'] > num_samples: + self._remaining_samples['repetitions'] -= num_samples + return [self._remaining_samples['value']] * num_samples + + samples, template_samples = self._sample_repetitions_from_generator(num_samples) + if num_samples > len(samples): + new_samples = self._sample_scale_fallback(num_samples - len(samples), template_samples) + samples.extend(new_samples) + + return samples + + def _sample(self, num_samples, unique_condition): + """Sample num_samples values.""" + if num_samples <= 0: + return [] + + if hasattr(self, 'cardinality_rule'): + if self.cardinality_rule == 'match': + return self._sample_match(num_samples) + + if self.cardinality_rule == 'scale': + return self._sample_scale(num_samples) # If there aren't enough values left in the generator, reset it if num_samples > self.generator_size - self.generated: self.reset_randomization() - samples = self._generate_as_many_as_possible(num_samples) - num_samples -= len(samples) - if num_samples > 0: + samples = self._sample_from_generator(num_samples) + if num_samples > len(samples): if unique_condition: - new_samples = self._generate_fallback_samples(num_samples, samples) + new_samples = self._sample_fallback(num_samples - len(samples), samples) else: - new_samples = self._generate_num_samples(num_samples, samples) + new_samples = self._sample_from_template(num_samples - len(samples), samples) samples.extend(new_samples) return samples + def _generate_unique_regexes(self): + regex_values = [] + try: + while len(regex_values) < self._data_cardinality: + regex_values.append(next(self.generator)) + except (RuntimeError, StopIteration): + fallback_samples = self._sample_fallback( + num_samples=self._data_cardinality - len(regex_values), + template_samples=regex_values, + ) + regex_values.extend(fallback_samples) + + return regex_values + def _reverse_transform(self, data): """Generate new data using the provided ``regex_format``. @@ -340,24 +439,16 @@ def _reverse_transform(self, data): """ if hasattr(self, 'cardinality_rule'): unique_condition = self.cardinality_rule == 'unique' - if self.cardinality_rule == 'match' and self._unique_regex_values is None: - self._unique_regex_values = self._generate_unique_regex_values() + match_cardinality = self.cardinality_rule == 'match' + if match_cardinality and self._unique_regex_values is None: + self._unique_regex_values = self._generate_unique_regexes() else: unique_condition = self.enforce_uniqueness + match_cardinality = False - if hasattr(self, '_unique_regex_values'): - match_cardinality_values = self._unique_regex_values - else: - match_cardinality_values = None - - if data is not None and len(data): - num_samples = len(data) - else: - num_samples = self.data_length - - match_condition = match_cardinality_values is not None - self._warn_not_enough_unique_values(num_samples, unique_condition, match_condition) - samples = self._generate_samples(num_samples, match_cardinality_values, unique_condition) + num_samples = len(data) if (data is not None and len(data)) else self.data_length + self._warn_not_enough_unique_values(num_samples, unique_condition, match_cardinality) + samples = self._sample(num_samples, unique_condition) if getattr(self, 'generation_order', 'alphanumeric') == 'scrambled': np.random.shuffle(samples) diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index c4e08b20..9de36d48 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -372,7 +372,6 @@ def __getitem__(self, sdtype): def _handle_enforce_uniqueness_and_cardinality_rule(enforce_uniqueness, cardinality_rule): - result = cardinality_rule if enforce_uniqueness is not None: warnings.warn( "The 'enforce_uniqueness' parameter is no longer supported. " @@ -380,6 +379,6 @@ def _handle_enforce_uniqueness_and_cardinality_rule(enforce_uniqueness, cardinal FutureWarning, ) if enforce_uniqueness and cardinality_rule is None: - result = 'unique' + return 'unique' - return result + return cardinality_rule diff --git a/tests/integration/transformers/test_id.py b/tests/integration/transformers/test_id.py index 66d4c45f..f3cb3fb5 100644 --- a/tests/integration/transformers/test_id.py +++ b/tests/integration/transformers/test_id.py @@ -1,4 +1,5 @@ import pickle +import warnings import numpy as np import pandas as pd @@ -426,6 +427,329 @@ def test_cardinality_rule_match_empty_regex(self): pd.testing.assert_frame_equal(reverse_transform_match, expected) pd.testing.assert_frame_equal(reverse_transform_none, expected) + def test_cardinality_rule_scale(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 100}) + instance = RegexGenerator(regex_format='[a-z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + assert len(recorded_warnings) == 0 + + # Assert + assert set(out['col']).issubset({'a', 'b', 'c'}) + + value_counts = out['col'].value_counts() + assert value_counts['a'] in {50, 100, 150} + assert value_counts.get('b', 0) in {0, 50, 100} + assert value_counts.get('c', 0) in {0, 50} + + assert value_counts.sum() == 150 + + def test_cardinality_rule_scale_nans(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': [np.nan] * 50 + ['B'] * 100}) + instance = RegexGenerator(regex_format='[a-z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + assert len(recorded_warnings) == 0 + + # Assert + assert set(out['col']).issubset({'a', 'b', 'c'}) + + value_counts = out['col'].value_counts() + assert value_counts['a'] in {50, 100, 150} + assert value_counts.get('b', 0) in {0, 50, 100} + assert value_counts.get('c', 0) in {0, 50} + + assert value_counts.sum() == 150 + + def test_cardinality_rule_scale_one_value(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + assert len(recorded_warnings) == 0 + + # Assert + pd.testing.assert_frame_equal(out, data) + + def test_cardinality_rule_scale_one_value_many_transform(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + instance.fit_transform(data, 'col') + out = instance.reverse_transform(pd.DataFrame(index=range(200))) + + assert len(recorded_warnings) == 0 + + # Assert + expected = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 50 + ['C'] * 50 + ['D'] * 50}) + pd.testing.assert_frame_equal(out, expected) + + def test_cardinality_rule_scale_empty_data(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': []}) + instance = RegexGenerator(regex_format='[a-z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + assert len(recorded_warnings) == 0 + + # Assert + pd.testing.assert_frame_equal(out, data, check_dtype=False) + + def test_cardinality_rule_scale_proportions(self): + """Test when cardinality rule is 'scale'.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(regex_format='[a-z]{3}', cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + value_counts = out['col'].value_counts() + one_count = (value_counts == 1).sum() + two_count = (value_counts == 2).sum() + three_count = (value_counts == 3).sum() + more_count = (value_counts > 3).sum() + + assert 900 < one_count < 1100 + assert 400 < two_count < 600 + assert 233 < three_count < 433 + assert len(out) == 3000 + assert more_count == 0 + + def test_cardinality_rule_scale_not_enough_regex_categorical(self): + """Test when cardinality rule is 'scale'.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(regex_format='[a-z]', cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + value_counts = out['col'].value_counts() + one_count = (value_counts == 1).sum() + two_count = (value_counts == 2).sum() + three_count = (value_counts == 3).sum() + more_count = (value_counts > 3).sum() + + assert 900 < one_count < 1100 + assert 400 < two_count < 600 + assert 233 < three_count < 433 + assert len(out) == 3000 + assert more_count == 0 + + def assert_proportions(self, out, samples): + value_counts = out['col'].value_counts() + one_count = (value_counts == 1).sum() + two_count = (value_counts == 2).sum() + three_count = (value_counts == 3).sum() + more_count = (value_counts > 3).sum() + + assert np.isclose(one_count, two_count * 2, atol=samples * 0.2) + assert np.isclose(one_count, three_count * 3, atol=samples * 0.2) + assert len(out) == samples + assert more_count == 0 + + def test_cardinality_rule_scale_not_enough_regex_numerical(self): + """Test when cardinality rule is 'scale'.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(regex_format='[1-3]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + assert len(recorded_warnings) == 1 + warn_msg = ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' + ) + assert warn_msg == str(recorded_warnings[0].message) + + # Assert + self.assert_proportions(out, 3000) + + def test_cardinality_rule_scale_called_multiple_times(self): + """Test calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(cardinality_rule='scale', generation_order='alphanumeric') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed_data = instance.fit_transform(data, 'col') + first_reverse_transform = instance.reverse_transform(transformed_data.head(500)) + second_reverse_transform = instance.reverse_transform(transformed_data.head(1000)) + third_reverse_transform = instance.reverse_transform(transformed_data.head(2000)) + fourth_reverse_transform = instance.reverse_transform(transformed_data.head(1111)) + + assert len(recorded_warnings) == 0 + + # Assert + self.assert_proportions(first_reverse_transform, 500) + self.assert_proportions(second_reverse_transform, 1000) + self.assert_proportions(third_reverse_transform, 2000) + self.assert_proportions(fourth_reverse_transform, 1111) + + first_set = set(first_reverse_transform['col']) + second_set = set(second_reverse_transform['col']) + third_set = set(third_reverse_transform['col']) + assert first_set.isdisjoint(set(second_reverse_transform['col'][200:])) + assert first_set.isdisjoint(set(third_reverse_transform['col'][200:])) + assert first_set.isdisjoint(set(fourth_reverse_transform['col'][200:])) + assert second_set.isdisjoint(set(third_reverse_transform['col'][200:])) + assert second_set.isdisjoint(set(fourth_reverse_transform['col'][200:])) + assert third_set.isdisjoint(set(fourth_reverse_transform['col'][200:])) + + def test_cardinality_rule_scale_called_multiple_times_not_enough_regex(self): + """Test calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(regex_format='[1-3]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed_data = instance.fit_transform(data, 'col') + first_reverse_transform = instance.reverse_transform(transformed_data.head(500)) + second_reverse_transform = instance.reverse_transform(transformed_data.head(1000)) + third_reverse_transform = instance.reverse_transform(transformed_data.head(2000)) + fourth_reverse_transform = instance.reverse_transform(transformed_data.head(1111)) + + assert len(recorded_warnings) == 4 + warn_msg = ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' + ) + for warning in recorded_warnings: + assert warn_msg == str(warning.message) + + # Assert + self.assert_proportions(first_reverse_transform, 500) + self.assert_proportions(second_reverse_transform, 1000) + self.assert_proportions(third_reverse_transform, 2000) + self.assert_proportions(fourth_reverse_transform, 1111) + + def test_cardinality_rule_scale_called_multiple_times_not_enough_regex_categorical(self): + """Test calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(regex_format='[a-z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed_data = instance.fit_transform(data, 'col') + first_reverse_transform = instance.reverse_transform(transformed_data.head(500)) + second_reverse_transform = instance.reverse_transform(transformed_data.head(1000)) + third_reverse_transform = instance.reverse_transform(transformed_data.head(2000)) + fourth_reverse_transform = instance.reverse_transform(transformed_data.head(1111)) + + assert len(recorded_warnings) == 4 + warn_msg = ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' + ) + for warning in recorded_warnings: + assert warn_msg == str(warning.message) + + # Assert + self.assert_proportions(first_reverse_transform, 500) + self.assert_proportions(second_reverse_transform, 1000) + self.assert_proportions(third_reverse_transform, 2000) + self.assert_proportions(fourth_reverse_transform, 1111) + + def test_cardinality_rule_scale_called_multiple_times_remaining_samples(self): + """Test calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + hundred = [i // 100 for i in range(1000)] + two_hundred = [i // 200 for i in range(2000, 3000)] + data = pd.DataFrame({'col': hundred + two_hundred}) + instance = RegexGenerator( + regex_format='[a-f]', cardinality_rule='scale', generation_order='alphanumeric' + ) + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed_data = instance.fit_transform(data, 'col') + first_out = instance.reverse_transform(transformed_data.head(250)) + second_out = instance.reverse_transform(transformed_data.head(3_000)) + + assert len(recorded_warnings) == 1 + warn_msg = ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' + ) + assert warn_msg == str(recorded_warnings[0].message) + + # Assert + assert len(first_out) == 250 + assert len(set(first_out['col'][200:])) == 1 + pd.testing.assert_series_equal( + first_out['col'][200:], + second_out['col'][:50], + check_index=False, + ) + assert second_out['col'][0] not in second_out['col'][50:] + class TestHyperTransformer: def test_end_to_end_scrambled(self): diff --git a/tests/unit/transformers/test_id.py b/tests/unit/transformers/test_id.py index a60afa20..cd689739 100644 --- a/tests/unit/transformers/test_id.py +++ b/tests/unit/transformers/test_id.py @@ -1,6 +1,7 @@ """Test for ID transformers.""" import re +import warnings from string import ascii_uppercase from unittest.mock import Mock, patch @@ -180,6 +181,8 @@ def test___getstate__(self): 'generation_order': 'alphanumeric', '_unique_regex_values': None, '_data_cardinality': None, + '_data_cardinality_scale': None, + '_remaining_samples': {'value': None, 'repetitions': 0}, } @patch('rdt.transformers.id.strings_from_regex') @@ -498,6 +501,59 @@ def test__reverse_transform_cardinality_rule_match_with_too_many_values_str(self # Assert assert instance._unique_regex_values == ['a', 'b', 'a(0)', 'b(0)', 'a(1)'] + assert instance._unique_regex_values == ['a', 'b', 'a(0)', 'b(0)', 'a(1)'] + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__fit_cardinality_rule_scale(self): + """Test it when cardinality_rule is 'scale'.""" + # Setup + instance = RegexGenerator(cardinality_rule='scale') + columns_data = pd.Series(['1', '2', '3', '4', '5', '5', '6', '6', '7', '7', '7', '7']) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 12 + assert instance._data_cardinality_scale == { + 'num_repetitions': [1, 2, 4], + 'frequency': [4 / 7, 2 / 7, 1 / 7], + } + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__fit_cardinality_rule_scale_nans(self): + """Test it when cardinality_rule is 'scale'.""" + # Setup + instance = RegexGenerator(cardinality_rule='scale') + columns_data = pd.Series([np.nan, np.nan, None, None, '1', 2]) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 6 + assert instance._data_cardinality_scale == { + 'num_repetitions': [1, 4], + 'frequency': [2 / 3, 1 / 3], + } + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__fit_cardinality_rule_scale_only_nans(self): + """Test it when cardinality_rule is 'scale'.""" + # Setup + instance = RegexGenerator(cardinality_rule='scale') + columns_data = pd.Series([np.nan, np.nan, None, None, float('nan'), float('nan')]) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 6 + assert instance._data_cardinality_scale == { + 'num_repetitions': [6], + 'frequency': [1], + } + assert instance.output_properties == {None: {'next_transformer': None}} def test__transform(self): """Test the ``_transform`` method. @@ -895,21 +951,124 @@ def test__reverse_transform_match_too_many_values(self): out, np.array(['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A']) ) - def test__reverse_transform_no_unique_regex_values_attribute(self): - """Test it without the _unique_regex_values attribute.""" + def test__reverse_transform_scale(self): + """Test when cardinality rule is 'scale'.""" # Setup - instance = RegexGenerator('[A-E]') - delattr(instance, '_unique_regex_values') - instance.data_length = 6 - generator = AsciiGenerator(5) - instance.generator = generator - instance.generator_size = 5 - instance.generated = 0 - instance.columns = ['a'] - columns_data = pd.Series() + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 100}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='scale') + instance.fit(data, 'col') # Run - out = instance._reverse_transform(columns_data) + out = instance._reverse_transform(data) # Assert - np.testing.assert_array_equal(out, np.array(['A', 'B', 'C', 'D', 'E', 'A'])) + assert set(out).issubset({'A', 'B', 'C'}) + + value_counts = pd.Series(out).value_counts() + assert value_counts['A'] in {50, 100, 150} + assert value_counts.get('B', 0) in {0, 50, 100} + assert value_counts.get('C', 0) in {0, 50} + + assert value_counts.sum() == 150 + + def test__reverse_transform_scale_empty_data(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': []}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(data) + + # Assert + assert isinstance(out, np.ndarray) + assert out.size == 0 + + def test__reverse_transform_scale_not_enough_regex(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 50 + ['C'] * 50}) + instance = RegexGenerator(regex_format='[A-B]', cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + out = instance._reverse_transform(data) + + assert len(recorded_warnings) == 1 + assert str(recorded_warnings[0].message) == ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' + ) + + # Assert + assert set(out).issubset({'A', 'B', 'A(0)'}) + + value_counts = pd.Series(out).value_counts() + assert value_counts['A'] in {50, 100, 150} + assert value_counts.get('B', 0) in {0, 50, 100} + assert value_counts.get('A(0)', 0) in {0, 50} + + assert value_counts.sum() == 150 + + def test__reverse_transform_scale_not_enough_regex_multiple_calls(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 50 + ['C'] * 50}) + instance = RegexGenerator(regex_format='[A-B]', cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + out1 = instance._reverse_transform(data) + out2 = instance._reverse_transform(data) + + assert len(recorded_warnings) == 2 + assert str(recorded_warnings[0].message) == ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' + ) + + # Assert + assert set(out1) == {'A', 'B', 'A(0)'} + assert set(out2) == {'A', 'B', 'A(0)'} + + for out in [out1, out2]: + value_counts = pd.Series(out).value_counts() + assert value_counts['A'] in {50, 100, 150} + assert value_counts.get('B', 0) in {0, 50, 100} + assert value_counts.get('A(0)', 0) in {0, 50} + assert value_counts.sum() == 150 + + def test__reverse_transform_scale_remaining_values(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 10 + ['B'] * 3}) + instance = RegexGenerator(regex_format='[A-B]', cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out1 = instance._reverse_transform(data.head(8)) + out2 = instance._reverse_transform(data) + + # Assert + assert set(out1).issubset({'A', 'B', 'A(0)'}) + assert set(out2).issubset({'A', 'B', 'A(0)', 'B(0)', 'B(1)', 'B(2)', 'B(3)'}) + + def test__reverse_transform_scale_many_remaining_values(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 100}) + instance = RegexGenerator(regex_format='[A-B]', cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out1 = instance._reverse_transform(data.head(10)) + out2 = instance._reverse_transform(data.head(10)) + + # Assert + assert np.array_equal(out1, np.array(['A'] * 10)) + assert np.array_equal(out2, np.array(['A'] * 10)) From 079b7da6bbbc621bf674ba514017fafdf58e0cbd Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Wed, 9 Apr 2025 14:57:47 -0700 Subject: [PATCH 3/3] Add `cardinality_rule='scale'` option for `AnonymizedFaker` (#979) --- rdt/transformers/id.py | 48 ++--- rdt/transformers/pii/anonymizer.py | 129 +++++++++---- rdt/transformers/utils.py | 27 +++ .../transformers/pii/test_anonymizer.py | 177 ++++++++++++++++++ tests/integration/transformers/test_id.py | 2 +- .../unit/transformers/pii/test_anonymizer.py | 128 +++++++++++-- 6 files changed, 436 insertions(+), 75 deletions(-) diff --git a/rdt/transformers/id.py b/rdt/transformers/id.py index 06ec9062..1ae3a32a 100644 --- a/rdt/transformers/id.py +++ b/rdt/transformers/id.py @@ -8,7 +8,9 @@ from rdt.transformers.base import BaseTransformer from rdt.transformers.utils import ( + _get_cardinality_frequency, _handle_enforce_uniqueness_and_cardinality_rule, + _sample_repetitions, fill_nan_with_none, strings_from_regex, ) @@ -185,6 +187,10 @@ def reset_randomization(self): self.generator, self.generator_size = strings_from_regex(self.regex_format) self.generated = 0 + if hasattr(self, 'cardinality_rule') and self.cardinality_rule == 'scale': + self._remaining_samples['repetitions'] = 0 + np.random.seed(self.random_seed) + def _sample_fallback(self, num_samples, template_samples): """Sample num_samples values such that they are all unique, disregarding the regex.""" try: @@ -204,16 +210,6 @@ def _sample_fallback(self, num_samples, template_samples): return samples[:num_samples] - def _get_cardinality_frequency(self, data): - """Get number of repetitions and their frequencies.""" - value_counts = data.value_counts(dropna=False) - repetition_counts = value_counts.value_counts().sort_index() - total = repetition_counts.sum() - frequencies = (repetition_counts / total).tolist() - repetitions = repetition_counts.index.tolist() - - return repetitions, frequencies - def _fit(self, data): """Fit the transformer to the data. @@ -230,7 +226,7 @@ def _fit(self, data): self._data_cardinality = data.nunique(dropna=False) elif self.cardinality_rule == 'scale': - sorted_counts, sorted_frequencies = self._get_cardinality_frequency(data) + sorted_counts, sorted_frequencies = _get_cardinality_frequency(data) self._data_cardinality_scale = { 'num_repetitions': sorted_counts, 'frequency': sorted_frequencies, @@ -321,21 +317,6 @@ def _sample_match(self, num_samples): return samples - def _sample_repetitions(self, num_samples, value): - """Sample a number of repetitions for a given value.""" - repetitions = np.random.choice( - self._data_cardinality_scale['num_repetitions'], - p=self._data_cardinality_scale['frequency'], - ) - if repetitions <= num_samples: - samples = [value] * repetitions - else: - samples = [value] * num_samples - self._remaining_samples['repetitions'] = repetitions - num_samples - self._remaining_samples['value'] = value - - return samples - def _sample_scale_fallback(self, num_samples, template_samples): """Sample num_samples values, disregarding the regex, for the cardinality rule 'scale'.""" warnings.warn( @@ -345,8 +326,12 @@ def _sample_scale_fallback(self, num_samples, template_samples): samples = [] fallback_samples = self._sample_fallback(num_samples, template_samples) while num_samples > len(samples): - fallback_sample = fallback_samples.pop(0) - new_samples = self._sample_repetitions(num_samples - len(samples), fallback_sample) + new_samples, self._remaining_samples = _sample_repetitions( + num_samples - len(samples), + fallback_samples.pop(0), + self._data_cardinality_scale.copy(), + self._remaining_samples.copy(), + ) samples.extend(new_samples) return samples @@ -369,7 +354,12 @@ def _sample_repetitions_from_generator(self, num_samples): else: break - new_samples = self._sample_repetitions(num_samples - len(samples), value) + new_samples, self._remaining_samples = _sample_repetitions( + num_samples - len(samples), + value, + self._data_cardinality_scale.copy(), + self._remaining_samples.copy(), + ) samples.extend(new_samples) return samples, template_samples diff --git a/rdt/transformers/pii/anonymizer.py b/rdt/transformers/pii/anonymizer.py index fbd93ec7..74558828 100644 --- a/rdt/transformers/pii/anonymizer.py +++ b/rdt/transformers/pii/anonymizer.py @@ -15,7 +15,11 @@ from rdt.errors import TransformerInputError, TransformerProcessingError from rdt.transformers.base import BaseTransformer from rdt.transformers.categorical import LabelEncoder -from rdt.transformers.utils import _handle_enforce_uniqueness_and_cardinality_rule +from rdt.transformers.utils import ( + _get_cardinality_frequency, + _handle_enforce_uniqueness_and_cardinality_rule, + _sample_repetitions, +) class AnonymizedFaker(BaseTransformer): @@ -38,6 +42,8 @@ class AnonymizedFaker(BaseTransformer): cardinality_rule (str): If ``'unique'`` enforce that every created value is unique. If ``'match'`` match the cardinality of the data seen during fit. + If set to 'scale', the generated data will match the number of repetitions that + each value is allowed to have. If ``None`` do not consider cardinality. Defaults to ``None``. enforce_uniqueness (bool): @@ -118,6 +124,8 @@ def __init__( missing_value_generation='random', ): super().__init__() + self._data_cardinality_scale = None + self._remaining_samples = {'value': None, 'repetitions': 0} self._data_cardinality = None self.data_length = None self.enforce_uniqueness = enforce_uniqueness @@ -182,10 +190,14 @@ def reset_randomization(self): self.faker = faker.Faker(self.locales) self.faker.seed_instance(self._faker_random_seed) + if hasattr(self, 'cardinality_rule') and self.cardinality_rule == 'scale': + self._remaining_samples['repetitions'] = 0 + np.random.seed(self.random_seed) + def _function(self): """Return the result of calling the ``faker`` function.""" try: - if self.cardinality_rule in {'unique', 'match'}: + if self.cardinality_rule in {'unique', 'match', 'scale'}: faker_attr = self.faker.unique else: faker_attr = self.faker @@ -222,45 +234,45 @@ def _fit(self, data): self._set_faker_seed(data) self.data_length = len(data) if self.missing_value_generation == 'random': - self._nan_frequency = data.isna().sum() / len(data) + self._nan_frequency = data.isna().sum() / len(data) if len(data) > 0 else 0.0 if self.cardinality_rule == 'match': # remove nans from data self._data_cardinality = len(data.dropna().unique()) + if self.cardinality_rule == 'scale': + sorted_counts, sorted_frequencies = _get_cardinality_frequency(data.dropna()) + self._data_cardinality_scale = { + 'num_repetitions': sorted_counts, + 'frequency': sorted_frequencies, + } + def _transform(self, _data): """Drop the input column by returning ``None``.""" return None - def _get_unique_categories(self, samples): - return np.array([self._function() for _ in range(samples)], dtype=object) - - def _reverse_transform_cardinality_rule_match(self, sample_size): - """Reverse transform the data when the cardinality rule is 'match'.""" - num_nans = self._calculate_num_nans(sample_size) - reverse_transformed = self._generate_nans(num_nans) - - if sample_size <= num_nans: - return reverse_transformed - - remaining_samples = sample_size - num_nans - sampled_values = self._generate_cardinality_match_values(remaining_samples) - - reverse_transformed = np.concatenate([reverse_transformed, sampled_values]) - np.random.shuffle(reverse_transformed) - - return reverse_transformed - - def _calculate_num_nans(self, sample_size): - """Calculate the number of NaN values to generate.""" - if self.missing_value_generation == 'random': - return int(self._nan_frequency * sample_size) + def _generate_cardinality_scale_values(self, remaining_samples): + """Generate sampled values while ensuring each unique category appears at least once.""" + if self._remaining_samples['repetitions'] >= remaining_samples: + self._remaining_samples['repetitions'] -= remaining_samples + return [self._remaining_samples['value']] * remaining_samples + + samples = [self._remaining_samples['value']] * self._remaining_samples['repetitions'] + self._remaining_samples['repetitions'] = 0 + + while len(samples) < remaining_samples: + new_samples, self._remaining_samples = _sample_repetitions( + remaining_samples - len(samples), + self._function(), + self._data_cardinality_scale.copy(), + self._remaining_samples.copy(), + ) + samples.extend(new_samples) - return 0 + return np.array(samples, dtype=object) - def _generate_nans(self, num_nans): - """Generate an array of NaN values.""" - return np.full(num_nans, np.nan, dtype=object) + def _get_unique_categories(self, samples): + return np.array([self._function() for _ in range(samples)], dtype=object) def _generate_cardinality_match_values(self, remaining_samples): """Generate sampled values while ensuring each unique category appears at least once.""" @@ -295,6 +307,36 @@ def _reverse_transform_with_fallback(self, sample_size): return np.array(reverse_transformed, dtype=object) + def _calculate_num_nans(self, sample_size): + """Calculate the number of NaN values to generate.""" + if self.missing_value_generation == 'random': + return int(self._nan_frequency * sample_size) + + return 0 + + def _generate_nans(self, num_nans): + """Generate an array of NaN values.""" + return np.full(num_nans, np.nan, dtype=object) + + def _reverse_transform_cardinality_rules(self, sample_size): + """Reverse transform the data when the cardinality rule is 'match' or 'scale'.""" + num_nans = self._calculate_num_nans(sample_size) + reverse_transformed = self._generate_nans(num_nans) + + if sample_size <= num_nans: + return reverse_transformed + + remaining_samples = sample_size - num_nans + if self.cardinality_rule == 'match': + sampled_values = self._generate_cardinality_match_values(remaining_samples) + else: + sampled_values = self._generate_cardinality_scale_values(remaining_samples) + + reverse_transformed = np.concatenate([reverse_transformed, sampled_values]) + np.random.shuffle(reverse_transformed) + + return reverse_transformed + def _reverse_transform(self, data): """Generate new anonymized data using a ``faker.provider.function``. @@ -310,19 +352,25 @@ def _reverse_transform(self, data): else: sample_size = self.data_length - if hasattr(self, 'cardinality_rule') and self.cardinality_rule == 'match': - reverse_transformed = self._reverse_transform_cardinality_rule_match(sample_size) + if hasattr(self, 'cardinality_rule') and self.cardinality_rule in {'match', 'scale'}: + reverse_transformed = self._reverse_transform_cardinality_rules(sample_size) else: reverse_transformed = self._reverse_transform_with_fallback(sample_size) - if self.missing_value_generation == 'random' and not pd.isna(reverse_transformed).any(): + if self.missing_value_generation == 'random' and pd.notna(reverse_transformed).all(): num_nans = int(self._nan_frequency * sample_size) nan_indices = np.random.choice(sample_size, num_nans, replace=False) reverse_transformed[nan_indices] = np.nan return reverse_transformed - def _set_fitted_parameters(self, column_name, nan_frequency=0.0, cardinality=None): + def _set_fitted_parameters( + self, + column_name, + nan_frequency=0.0, + cardinality=None, + cardinality_scale=None, + ): """Manually set the parameters on the transformer to get it into a fitted state. Args: @@ -334,6 +382,12 @@ def _set_fitted_parameters(self, column_name, nan_frequency=0.0, cardinality=Non cardinality (int or None): The number of unique values to generate if cardinality rule is set to 'match'. + cardinality_scale (dict or None): + The frequency of each number of repetitions in the data: + { + 'num_repetitions': list of int, + 'frequency': list of float, + } """ self.reset_randomization() self.columns = [column_name] @@ -344,7 +398,14 @@ def _set_fitted_parameters(self, column_name, nan_frequency=0.0, cardinality=Non 'Cardinality "match" rule must specify a cardinality value.' ) + if self.cardinality_rule == 'scale': + if not cardinality_scale: + raise TransformerInputError( + 'Cardinality "scale" rule must specify a cardinality value.' + ) + self._data_cardinality = cardinality + self._data_cardinality_scale = cardinality_scale self._nan_frequency = nan_frequency def __repr__(self): diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index 9de36d48..4382bab9 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -382,3 +382,30 @@ def _handle_enforce_uniqueness_and_cardinality_rule(enforce_uniqueness, cardinal return 'unique' return cardinality_rule + + +def _get_cardinality_frequency(data): + """Get number of repetitions of values in the data and their frequencies.""" + value_counts = data.value_counts(dropna=False) + repetition_counts = value_counts.value_counts().sort_index() + total = repetition_counts.sum() + frequencies = (repetition_counts / total).tolist() + repetitions = repetition_counts.index.tolist() + + return repetitions, frequencies + + +def _sample_repetitions(num_samples, value, data_cardinality_scale, remaining_samples): + """Sample a number of repetitions for a given value.""" + repetitions = np.random.choice( + data_cardinality_scale['num_repetitions'], + p=data_cardinality_scale['frequency'], + ) + if repetitions <= num_samples: + samples = [value] * repetitions + else: + samples = [value] * num_samples + remaining_samples['repetitions'] = repetitions - num_samples + remaining_samples['value'] = value + + return samples, remaining_samples diff --git a/tests/integration/transformers/pii/test_anonymizer.py b/tests/integration/transformers/pii/test_anonymizer.py index 7a100f99..314c072d 100644 --- a/tests/integration/transformers/pii/test_anonymizer.py +++ b/tests/integration/transformers/pii/test_anonymizer.py @@ -318,6 +318,183 @@ def test_anonymized_faker_produces_only_n_values_for_each_reverse_transform_card # Assert assert set(first_reverse_transformed['name']) == set(second_reverse_transformed['name']) + def test_cardinality_rule_scale(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 100}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + assert set(out['col']) == {'KAab', 'qOSU'} + + value_counts = out['col'].value_counts() + assert value_counts['KAab'] == 50 + assert value_counts['qOSU'] == 100 + + def test_cardinality_rule_scale_nans(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': [np.nan] * 50 + ['B'] * 100}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + value_counts = out['col'].value_counts() + assert value_counts['MGWz'] == 100 + assert out['col'].isna().sum() == 50 + + def test_cardinality_rule_scale_one_value(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + pd.testing.assert_frame_equal(out, pd.DataFrame({'col': ['qOSU'] * 50})) + + def test_cardinality_rule_scale_one_value_many_transform(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + instance.fit_transform(data, 'col') + out = instance.reverse_transform(pd.DataFrame(index=range(200))) + + # Assert + value_counts = out['col'].value_counts() + assert value_counts['qOSU'] == 50 + assert value_counts['JEWW'] == 50 + assert value_counts['KAab'] == 50 + assert value_counts['CPmg'] == 50 + + def test_cardinality_rule_scale_empty_data(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': []}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + pd.testing.assert_frame_equal(out, data, check_dtype=False) + + def test_cardinality_rule_scale_proportions(self): + """Test when cardinality rule is 'scale'.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + value_counts = out['col'].value_counts() + one_count = (value_counts == 1).sum() + two_count = (value_counts == 2).sum() + three_count = (value_counts == 3).sum() + more_count = (value_counts > 3).sum() + + assert 900 < one_count < 1100 + assert 400 < two_count < 600 + assert 233 < three_count < 433 + assert len(out) == 3000 + assert more_count == 0 + + def assert_proportions(self, out, samples): + value_counts = out['col'].value_counts() + one_count = (value_counts == 1).sum() + two_count = (value_counts == 2).sum() + three_count = (value_counts == 3).sum() + more_count = (value_counts > 3).sum() + + assert np.isclose(one_count, two_count * 2, atol=samples * 0.2) + assert np.isclose(one_count, three_count * 3, atol=samples * 0.2) + assert len(out) == samples + assert more_count <= 1 + + def test_cardinality_rule_scale_called_multiple_times(self): + """Test calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed_data = instance.fit_transform(data, 'col') + first_reverse_transform = instance.reverse_transform(transformed_data.head(500)) + second_reverse_transform = instance.reverse_transform(transformed_data.head(1000)) + third_reverse_transform = instance.reverse_transform(transformed_data.head(2000)) + fourth_reverse_transform = instance.reverse_transform(transformed_data.head(1111)) + + # Assert + self.assert_proportions(first_reverse_transform, 500) + self.assert_proportions(second_reverse_transform, 1000) + self.assert_proportions(third_reverse_transform, 2000) + self.assert_proportions(fourth_reverse_transform, 1111) + self.assert_proportions( + pd.concat([ + first_reverse_transform, + second_reverse_transform, + third_reverse_transform, + fourth_reverse_transform, + ]), + 4611, + ) + + first_set = set(first_reverse_transform['col']) + second_set = set(second_reverse_transform['col']) + third_set = set(third_reverse_transform['col']) + fourth_set = set(fourth_reverse_transform['col']) + + assert len(first_set.intersection(second_set)) <= 1 + assert len(first_set.intersection(third_set)) <= 1 + assert len(first_set.intersection(fourth_set)) <= 1 + assert len(second_set.intersection(third_set)) <= 1 + assert len(second_set.intersection(fourth_set)) <= 1 + assert len(third_set.intersection(fourth_set)) <= 1 + + def test_cardinality_rule_scale_called_multiple_times_remaining_samples(self): + """Test calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + hundred = [i // 100 for i in range(1000)] + two_hundred = [i // 200 for i in range(2000, 3000)] + data = pd.DataFrame({'col': hundred + two_hundred}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed_data = instance.fit_transform(data, 'col') + first_out = instance.reverse_transform(transformed_data.head(250)) + remaining_value = instance._remaining_samples['value'] + remaining_samples = instance._remaining_samples['repetitions'] + second_out = instance.reverse_transform(transformed_data) + + # Assert + assert len(first_out) == 250 + assert len(first_out[first_out['col'] == remaining_value]) == 50 + assert len(second_out['col']) == 2_000 + assert len(second_out[second_out['col'] == remaining_value]) == remaining_samples + class TestPsuedoAnonymizedFaker: def test_default_settings(self): diff --git a/tests/integration/transformers/test_id.py b/tests/integration/transformers/test_id.py index f3cb3fb5..b4d34ead 100644 --- a/tests/integration/transformers/test_id.py +++ b/tests/integration/transformers/test_id.py @@ -589,7 +589,7 @@ def assert_proportions(self, out, samples): assert np.isclose(one_count, two_count * 2, atol=samples * 0.2) assert np.isclose(one_count, three_count * 3, atol=samples * 0.2) assert len(out) == samples - assert more_count == 0 + assert more_count <= 1 def test_cardinality_rule_scale_not_enough_regex_numerical(self): """Test when cardinality rule is 'scale'.""" diff --git a/tests/unit/transformers/pii/test_anonymizer.py b/tests/unit/transformers/pii/test_anonymizer.py index c9dde11a..550a3abf 100644 --- a/tests/unit/transformers/pii/test_anonymizer.py +++ b/tests/unit/transformers/pii/test_anonymizer.py @@ -581,21 +581,21 @@ def test__reverse_transform_match_cardinality(self): AnonymizedFaker._reverse_transform(instance, None) # Assert - instance._reverse_transform_cardinality_rule_match.assert_called_once_with(3) + instance._reverse_transform_cardinality_rules.assert_called_once_with(3) - def test__reverse_transform_cardinality_rule_match_only_nans(self): + def test__reverse_transform_cardinality_rules_only_nans(self): """Test it with only nans.""" # Setup instance = AnonymizedFaker() instance._nan_frequency = 1 # Run - result = instance._reverse_transform_cardinality_rule_match(3) + result = instance._reverse_transform_cardinality_rules(3) # Assert assert pd.isna(result).all() - def test__reverse_transform_cardinality_rule_match_no_missing_value(self): + def test__reverse_transform_cardinality_rules_no_missing_value(self): """Test it with default values.""" # Setup instance = AnonymizedFaker(missing_value_generation=None) @@ -604,16 +604,37 @@ def test__reverse_transform_cardinality_rule_match_no_missing_value(self): instance._unique_categories = ['a', 'b', 'c'] function = Mock() function.side_effect = ['a', 'b', 'c'] - + instance.cardinality_rule = 'match' instance._function = function # Run - result = instance._reverse_transform_cardinality_rule_match(3) + result = instance._reverse_transform_cardinality_rules(3) # Assert assert set(result) == set(['a', 'b', 'c']) - def test__reverse_transform_cardinality_rule_match_not_enough_unique(self): + def test__reverse_transform_cardinality_rules_scale(self): + """Test it with scale cardinality.""" + # Setup + instance = AnonymizedFaker(missing_value_generation=None) + instance._data_cardinality = 2 + instance._nan_frequency = 0 + instance._data_cardinality_scale = { + 'num_repetitions': [1, 2, 3], + 'frequency': [0.1, 0.2, 0.7], + } + function = Mock() + function.side_effect = ['a', 'b', 'c'] + instance.cardinality_rule = 'scale' + instance._function = function + + # Run + result = instance._reverse_transform_cardinality_rules(3) + + # Assert + assert set(result).issubset(set(['a', 'b', 'c'])) + + def test__reverse_transform_cardinality_rules_not_enough_unique(self): """Test it when there are not enough unique values.""" # Setup instance = AnonymizedFaker() @@ -622,9 +643,10 @@ def test__reverse_transform_cardinality_rule_match_not_enough_unique(self): function = Mock() function.side_effect = ['a', 'b', 'c', 'd'] instance._function = function + instance.cardinality_rule = 'match' # Run - result = instance._reverse_transform_cardinality_rule_match(6) + result = instance._reverse_transform_cardinality_rules(6) # Assert assert set(result) == {'a', 'b', 'c'} @@ -747,28 +769,112 @@ def test__reverse_transform_size_is_length_of_data(self): assert function.call_args_list == [call(), call(), call()] np.testing.assert_array_equal(result, np.array(['a', 'b', 'c'])) + def test__reverse_transform_scale(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 100}) + instance = AnonymizedFaker(cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(data) + + # Assert + assert out[out == 'KAab'].size in {50, 100, 150} + assert out[out == 'qOSU'].size in {0, 50, 100} + + def test__reverse_transform_scale_multiple_calls(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 50 + ['C'] * 50}) + instance = AnonymizedFaker(cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out1 = instance._reverse_transform(data) + out2 = instance._reverse_transform(data) + instance.reset_randomization() + out3 = instance._reverse_transform(data) + + # Assert + assert out1[out1 == 'KAab'].size == 50 + assert out1[out1 == 'qOSU'].size == 50 + assert out1[out1 == 'CPmg'].size == 50 + + assert out2[out2 == 'urbw'].size == 50 + assert out2[out2 == 'JEWW'].size == 50 + assert out2[out2 == 'LRyt'].size == 50 + + assert out3[out3 == 'KAab'].size == 50 + assert out3[out3 == 'qOSU'].size == 50 + assert out1[out1 == 'CPmg'].size == 50 + + def test__reverse_transform_scale_remaining_values(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 10 + ['B'] * 3}) + instance = AnonymizedFaker(cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out1 = instance._reverse_transform(data.head(8)) + out2 = instance._reverse_transform(data) + + # Assert + assert out1[out1 == 'qOSU'].size == 3 + assert out1[out1 == 'KAab'].size == 5 + assert out2[out2 == 'KAab'].size == 5 + assert out2[out2 == 'CPmg'].size == 8 + + def test__reverse_transform_scale_many_remaining_values(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 100}) + instance = AnonymizedFaker(cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out1 = instance._reverse_transform(data.head(10)) + out2 = instance._reverse_transform(data.head(10)) + + # Assert + assert np.array_equal(out1, np.array(['qOSU'] * 10)) + assert np.array_equal(out2, np.array(['qOSU'] * 10)) + def test__set_fitted_parameters(self): """Test ``_set_fitted_parameters`` sets the required parameters for transformer.""" # Setup transformer = AnonymizedFaker() - transformer.cardinality_rule = 'match' frequency = 0.30 cardinality = 3 column_name = 'mock' - error_msg = re.escape('Cardinality "match" rule must specify a cardinality value.') # Run + transformer.cardinality_rule = 'match' + error_msg = re.escape('Cardinality "match" rule must specify a cardinality value.') + with pytest.raises(TransformerInputError, match=error_msg): + transformer._set_fitted_parameters(column_name, nan_frequency=frequency) + + transformer.cardinality_rule = 'scale' + error_msg = re.escape('Cardinality "scale" rule must specify a cardinality value.') with pytest.raises(TransformerInputError, match=error_msg): transformer._set_fitted_parameters(column_name, nan_frequency=frequency) transformer._set_fitted_parameters( - column_name, nan_frequency=frequency, cardinality=cardinality + column_name, + nan_frequency=frequency, + cardinality=cardinality, + cardinality_scale={'num_repetitions': [1, 2, 3], 'frequency': [0.1, 0.2, 0.7]}, ) # Assert assert transformer._nan_frequency == frequency assert transformer._data_cardinality == cardinality assert transformer.columns == [column_name] + assert transformer._data_cardinality_scale == { + 'num_repetitions': [1, 2, 3], + 'frequency': [0.1, 0.2, 0.7], + } def test___repr__default(self): """Test the ``__repr__`` method.