diff --git a/rdt/transformers/id.py b/rdt/transformers/id.py index 99adb3f4..1ae3a32a 100644 --- a/rdt/transformers/id.py +++ b/rdt/transformers/id.py @@ -8,7 +8,10 @@ from rdt.transformers.base import BaseTransformer from rdt.transformers.utils import ( + _get_cardinality_frequency, _handle_enforce_uniqueness_and_cardinality_rule, + _sample_repetitions, + fill_nan_with_none, strings_from_regex, ) @@ -97,16 +100,21 @@ class RegexGenerator(BaseTransformer): ``regex`` format. Args: - regex (str): + regex_format (str): String representing the regex function. enforce_uniqueness (bool): **DEPRECATED** Whether or not to ensure that the new generated data is all unique. If it isn't possible to create the requested number of rows, then an error will be raised. Defaults to ``None``. cardinality_rule (str): - Rule that the generated data must follow. If set to ``unique``, the generated - data must be unique. If set to ``None``, then the generated data may contain - duplicates. Defaults to ``None``. + Rule that the generated data must follow. + - If set to 'unique', the generated data must be unique. + - If set to 'match', the generated data will have the exact same cardinality + (number of unique values) as the real data. + - If set to 'scale', the generated data will match the number of repetitions that + each value is allowed to have. + - If set to ``None``, then the generated data may contain duplicates. + Defaults to ``None``. generation_order (str): String defining how to generate the output. If set to ``alphanumeric``, it will generate the output in alphanumeric order (ie. 'aaa', 'aab' or '1', '2'...). 
If @@ -156,19 +164,52 @@ def __init__( ) self.data_length = None self.generator = None - self.generator_size = None - self.generated = None if generation_order not in ['alphanumeric', 'scrambled']: raise ValueError("generation_order must be one of 'alphanumeric' or 'scrambled'.") self.generation_order = generation_order + # Used when cardinality_rule is 'scale' + self._data_cardinality_scale = None + self._remaining_samples = {'value': None, 'repetitions': 0} + + # Used when cardinality_rule is 'match' + self._data_cardinality = None + self._unique_regex_values = None + + # Used otherwise + self.generator_size = None + self.generated = None + def reset_randomization(self): """Create a new generator and reset the generated values counter.""" super().reset_randomization() self.generator, self.generator_size = strings_from_regex(self.regex_format) self.generated = 0 + if hasattr(self, 'cardinality_rule') and self.cardinality_rule == 'scale': + self._remaining_samples['repetitions'] = 0 + np.random.seed(self.random_seed) + + def _sample_fallback(self, num_samples, template_samples): + """Sample num_samples values such that they are all unique, disregarding the regex.""" + try: + # Integer-based fallback: attempt to convert the last template sample to an integer + # and then generate values in a sequential manner. + start = int(template_samples[-1]) + 1 + return [str(i) for i in range(start, start + num_samples)] + + except ValueError: + # String-based fallback: if the integer conversion fails, it uses the template + # samples as a base and appends a counter to make each value unique. + counter = 0 + samples = [] + while num_samples > len(samples): + samples.extend([f'{i}({counter})' for i in template_samples[:num_samples]]) + counter += 1 + + return samples[:num_samples] + def _fit(self, data): """Fit the transformer to the data. 
@@ -179,11 +220,23 @@ def _fit(self, data): self.reset_randomization() self.data_length = len(data) + if hasattr(self, 'cardinality_rule'): + data = fill_nan_with_none(data) + if self.cardinality_rule == 'match': + self._data_cardinality = data.nunique(dropna=False) + + elif self.cardinality_rule == 'scale': + sorted_counts, sorted_frequencies = _get_cardinality_frequency(data) + self._data_cardinality_scale = { + 'num_repetitions': sorted_counts, + 'frequency': sorted_frequencies, + } + def _transform(self, _data): """Drop the input column by returning ``None``.""" return None - def _warn_not_enough_unique_values(self, sample_size, unique_condition): + def _warn_not_enough_unique_values(self, sample_size, unique_condition, match_cardinality): """Warn the user that the regex cannot generate enough unique values. Args: @@ -191,15 +244,18 @@ def _warn_not_enough_unique_values(self, sample_size, unique_condition): Number of samples to be generated. unique_condition (bool): Whether or not to enforce uniqueness. + match_cardinality (bool): + Whether or not to match the cardinality of the data. """ warned = False + warn_msg = ( + f"The regex for '{self.get_input_column()}' can only generate " + f'{int(self.generator_size)} unique values. Additional values may not exactly ' + 'follow the provided regex.' + ) if sample_size > self.generator_size: if unique_condition: - warnings.warn( - f"The regex for '{self.get_input_column()}' can only generate " - f'{self.generator_size} unique values. Additional values may not exactly ' - 'follow the provided regex.' - ) + warnings.warn(warn_msg) warned = True else: LOGGER.info( @@ -218,6 +274,149 @@ def _warn_not_enough_unique_values(self, sample_size, unique_condition): f'values (only {max(remaining, 0)} unique values left).' ) + if match_cardinality: + if self._data_cardinality > sample_size: + warnings.warn( + f'Only {sample_size} values can be generated. 
Cannot match the cardinality ' + f'of the data, it requires {self._data_cardinality} values.' + ) + if sample_size > self.generator_size and self._data_cardinality > self.generator_size: + warnings.warn(warn_msg) + + def _sample_from_generator(self, num_samples): + """Generate samples. + + Generate values following the regex until either the sample size is reached or + the generator is exhausted. + """ + samples = [] + try: + while len(samples) < num_samples: + samples.append(next(self.generator)) + self.generated += 1 + except (RuntimeError, StopIteration): + pass + + return samples + + def _sample_from_template(self, num_samples, template_samples): + """Sample num_samples values from template_samples in a cycle. + + Eg: num_samples = 5, template_samples = ['a', 'b'] + The output will be ['a', 'b', 'a', 'b', 'a'] + """ + repeats = num_samples // len(template_samples) + 1 + return np.tile(template_samples, repeats)[:num_samples].tolist() + + def _sample_match(self, num_samples): + """Sample num_samples values following the 'match' cardinality rule.""" + samples = self._unique_regex_values[:num_samples] + if num_samples > len(samples): + new_samples = self._sample_from_template(num_samples - len(samples), samples) + samples.extend(new_samples) + + return samples + + def _sample_scale_fallback(self, num_samples, template_samples): + """Sample num_samples values, disregarding the regex, for the cardinality rule 'scale'.""" + warnings.warn( + f"The regex for '{self.get_input_column()}' cannot generate enough samples. " + 'Additional values may not exactly follow the provided regex.' 
+ ) + samples = [] + fallback_samples = self._sample_fallback(num_samples, template_samples) + while num_samples > len(samples): + new_samples, self._remaining_samples = _sample_repetitions( + num_samples - len(samples), + fallback_samples.pop(0), + self._data_cardinality_scale.copy(), + self._remaining_samples.copy(), + ) + samples.extend(new_samples) + + return samples + + def _sample_repetitions_from_generator(self, num_samples): + """Sample num_samples values from the generator, or until the generator is exhausted.""" + samples = [self._remaining_samples['value']] * self._remaining_samples['repetitions'] + self._remaining_samples['repetitions'] = 0 + + template_samples = [] + while num_samples > len(samples): + try: + value = next(self.generator) + template_samples.append(value) + except (RuntimeError, StopIteration): + # If the generator is exhausted and no samples have been generated yet, reset it + if len(template_samples) == 0: + self.reset_randomization() + continue + else: + break + + new_samples, self._remaining_samples = _sample_repetitions( + num_samples - len(samples), + value, + self._data_cardinality_scale.copy(), + self._remaining_samples.copy(), + ) + samples.extend(new_samples) + + return samples, template_samples + + def _sample_scale(self, num_samples): + """Sample num_samples values following the 'scale' cardinality rule.""" + if self._remaining_samples['repetitions'] > num_samples: + self._remaining_samples['repetitions'] -= num_samples + return [self._remaining_samples['value']] * num_samples + + samples, template_samples = self._sample_repetitions_from_generator(num_samples) + if num_samples > len(samples): + new_samples = self._sample_scale_fallback(num_samples - len(samples), template_samples) + samples.extend(new_samples) + + return samples + + def _sample(self, num_samples, unique_condition): + """Sample num_samples values.""" + if num_samples <= 0: + return [] + + if hasattr(self, 'cardinality_rule'): + if self.cardinality_rule == 
'match': + return self._sample_match(num_samples) + + if self.cardinality_rule == 'scale': + return self._sample_scale(num_samples) + + # If there aren't enough values left in the generator, reset it + if num_samples > self.generator_size - self.generated: + self.reset_randomization() + + samples = self._sample_from_generator(num_samples) + if num_samples > len(samples): + if unique_condition: + new_samples = self._sample_fallback(num_samples - len(samples), samples) + else: + new_samples = self._sample_from_template(num_samples - len(samples), samples) + samples.extend(new_samples) + + return samples + + def _generate_unique_regexes(self): + regex_values = [] + try: + while len(regex_values) < self._data_cardinality: + regex_values.append(next(self.generator)) + except (RuntimeError, StopIteration): + fallback_samples = self._sample_fallback( + num_samples=self._data_cardinality - len(regex_values), + template_samples=regex_values, + ) + regex_values.extend(fallback_samples) + + return regex_values + def _reverse_transform(self, data): """Generate new data using the provided ``regex_format``. 
@@ -230,56 +429,18 @@ def _reverse_transform(self, data): """ if hasattr(self, 'cardinality_rule'): unique_condition = self.cardinality_rule == 'unique' + match_cardinality = self.cardinality_rule == 'match' + if match_cardinality and self._unique_regex_values is None: + self._unique_regex_values = self._generate_unique_regexes() else: unique_condition = self.enforce_uniqueness + match_cardinality = False - if data is not None and len(data): - sample_size = len(data) - else: - sample_size = self.data_length - - self._warn_not_enough_unique_values(sample_size, unique_condition) - - remaining = self.generator_size - self.generated - if sample_size > remaining: - self.reset_randomization() - remaining = self.generator_size - - generated_values = [] - while len(generated_values) < sample_size: - try: - generated_values.append(next(self.generator)) - self.generated += 1 - except (RuntimeError, StopIteration): - # Can't generate more rows without collision so breaking out of loop - break - - reverse_transformed = generated_values[:] - - if len(reverse_transformed) < sample_size: - if unique_condition: - try: - remaining_samples = sample_size - len(reverse_transformed) - start = int(generated_values[-1]) + 1 - reverse_transformed.extend([ - str(i) for i in range(start, start + remaining_samples) - ]) - - except ValueError: - counter = 0 - while len(reverse_transformed) < sample_size: - remaining_samples = sample_size - len(reverse_transformed) - reverse_transformed.extend([ - f'{i}({counter})' for i in generated_values[:remaining_samples] - ]) - counter += 1 - - else: - while len(reverse_transformed) < sample_size: - remaining_samples = sample_size - len(reverse_transformed) - reverse_transformed.extend(generated_values[:remaining_samples]) + num_samples = len(data) if (data is not None and len(data)) else self.data_length + self._warn_not_enough_unique_values(num_samples, unique_condition, match_cardinality) + samples = self._sample(num_samples, unique_condition) if 
getattr(self, 'generation_order', 'alphanumeric') == 'scrambled': - np.random.shuffle(reverse_transformed) + np.random.shuffle(samples) - return np.array(reverse_transformed, dtype=object) + return np.array(samples, dtype=object) diff --git a/rdt/transformers/pii/anonymizer.py b/rdt/transformers/pii/anonymizer.py index fbd93ec7..74558828 100644 --- a/rdt/transformers/pii/anonymizer.py +++ b/rdt/transformers/pii/anonymizer.py @@ -15,7 +15,11 @@ from rdt.errors import TransformerInputError, TransformerProcessingError from rdt.transformers.base import BaseTransformer from rdt.transformers.categorical import LabelEncoder -from rdt.transformers.utils import _handle_enforce_uniqueness_and_cardinality_rule +from rdt.transformers.utils import ( + _get_cardinality_frequency, + _handle_enforce_uniqueness_and_cardinality_rule, + _sample_repetitions, +) class AnonymizedFaker(BaseTransformer): @@ -38,6 +42,8 @@ class AnonymizedFaker(BaseTransformer): cardinality_rule (str): If ``'unique'`` enforce that every created value is unique. If ``'match'`` match the cardinality of the data seen during fit. + If set to 'scale', the generated data will match the number of repetitions that + each value is allowed to have. If ``None`` do not consider cardinality. Defaults to ``None``. 
enforce_uniqueness (bool): @@ -118,6 +124,8 @@ def __init__( missing_value_generation='random', ): super().__init__() + self._data_cardinality_scale = None + self._remaining_samples = {'value': None, 'repetitions': 0} self._data_cardinality = None self.data_length = None self.enforce_uniqueness = enforce_uniqueness @@ -182,10 +190,14 @@ def reset_randomization(self): self.faker = faker.Faker(self.locales) self.faker.seed_instance(self._faker_random_seed) + if hasattr(self, 'cardinality_rule') and self.cardinality_rule == 'scale': + self._remaining_samples['repetitions'] = 0 + np.random.seed(self.random_seed) + def _function(self): """Return the result of calling the ``faker`` function.""" try: - if self.cardinality_rule in {'unique', 'match'}: + if self.cardinality_rule in {'unique', 'match', 'scale'}: faker_attr = self.faker.unique else: faker_attr = self.faker @@ -222,45 +234,45 @@ def _fit(self, data): self._set_faker_seed(data) self.data_length = len(data) if self.missing_value_generation == 'random': - self._nan_frequency = data.isna().sum() / len(data) + self._nan_frequency = data.isna().sum() / len(data) if len(data) > 0 else 0.0 if self.cardinality_rule == 'match': # remove nans from data self._data_cardinality = len(data.dropna().unique()) + if self.cardinality_rule == 'scale': + sorted_counts, sorted_frequencies = _get_cardinality_frequency(data.dropna()) + self._data_cardinality_scale = { + 'num_repetitions': sorted_counts, + 'frequency': sorted_frequencies, + } + def _transform(self, _data): """Drop the input column by returning ``None``.""" return None - def _get_unique_categories(self, samples): - return np.array([self._function() for _ in range(samples)], dtype=object) - - def _reverse_transform_cardinality_rule_match(self, sample_size): - """Reverse transform the data when the cardinality rule is 'match'.""" - num_nans = self._calculate_num_nans(sample_size) - reverse_transformed = self._generate_nans(num_nans) - - if sample_size <= num_nans: - 
return reverse_transformed - - remaining_samples = sample_size - num_nans - sampled_values = self._generate_cardinality_match_values(remaining_samples) - - reverse_transformed = np.concatenate([reverse_transformed, sampled_values]) - np.random.shuffle(reverse_transformed) - - return reverse_transformed - - def _calculate_num_nans(self, sample_size): - """Calculate the number of NaN values to generate.""" - if self.missing_value_generation == 'random': - return int(self._nan_frequency * sample_size) + def _generate_cardinality_scale_values(self, remaining_samples): + """Generate sampled values while ensuring each unique category appears at least once.""" + if self._remaining_samples['repetitions'] >= remaining_samples: + self._remaining_samples['repetitions'] -= remaining_samples + return [self._remaining_samples['value']] * remaining_samples + + samples = [self._remaining_samples['value']] * self._remaining_samples['repetitions'] + self._remaining_samples['repetitions'] = 0 + + while len(samples) < remaining_samples: + new_samples, self._remaining_samples = _sample_repetitions( + remaining_samples - len(samples), + self._function(), + self._data_cardinality_scale.copy(), + self._remaining_samples.copy(), + ) + samples.extend(new_samples) - return 0 + return np.array(samples, dtype=object) - def _generate_nans(self, num_nans): - """Generate an array of NaN values.""" - return np.full(num_nans, np.nan, dtype=object) + def _get_unique_categories(self, samples): + return np.array([self._function() for _ in range(samples)], dtype=object) def _generate_cardinality_match_values(self, remaining_samples): """Generate sampled values while ensuring each unique category appears at least once.""" @@ -295,6 +307,36 @@ def _reverse_transform_with_fallback(self, sample_size): return np.array(reverse_transformed, dtype=object) + def _calculate_num_nans(self, sample_size): + """Calculate the number of NaN values to generate.""" + if self.missing_value_generation == 'random': + 
return int(self._nan_frequency * sample_size) + + return 0 + + def _generate_nans(self, num_nans): + """Generate an array of NaN values.""" + return np.full(num_nans, np.nan, dtype=object) + + def _reverse_transform_cardinality_rules(self, sample_size): + """Reverse transform the data when the cardinality rule is 'match' or 'scale'.""" + num_nans = self._calculate_num_nans(sample_size) + reverse_transformed = self._generate_nans(num_nans) + + if sample_size <= num_nans: + return reverse_transformed + + remaining_samples = sample_size - num_nans + if self.cardinality_rule == 'match': + sampled_values = self._generate_cardinality_match_values(remaining_samples) + else: + sampled_values = self._generate_cardinality_scale_values(remaining_samples) + + reverse_transformed = np.concatenate([reverse_transformed, sampled_values]) + np.random.shuffle(reverse_transformed) + + return reverse_transformed + def _reverse_transform(self, data): """Generate new anonymized data using a ``faker.provider.function``. 
@@ -310,19 +352,25 @@ def _reverse_transform(self, data): else: sample_size = self.data_length - if hasattr(self, 'cardinality_rule') and self.cardinality_rule == 'match': - reverse_transformed = self._reverse_transform_cardinality_rule_match(sample_size) + if hasattr(self, 'cardinality_rule') and self.cardinality_rule in {'match', 'scale'}: + reverse_transformed = self._reverse_transform_cardinality_rules(sample_size) else: reverse_transformed = self._reverse_transform_with_fallback(sample_size) - if self.missing_value_generation == 'random' and not pd.isna(reverse_transformed).any(): + if self.missing_value_generation == 'random' and pd.notna(reverse_transformed).all(): num_nans = int(self._nan_frequency * sample_size) nan_indices = np.random.choice(sample_size, num_nans, replace=False) reverse_transformed[nan_indices] = np.nan return reverse_transformed - def _set_fitted_parameters(self, column_name, nan_frequency=0.0, cardinality=None): + def _set_fitted_parameters( + self, + column_name, + nan_frequency=0.0, + cardinality=None, + cardinality_scale=None, + ): """Manually set the parameters on the transformer to get it into a fitted state. Args: @@ -334,6 +382,12 @@ def _set_fitted_parameters(self, column_name, nan_frequency=0.0, cardinality=Non cardinality (int or None): The number of unique values to generate if cardinality rule is set to 'match'. + cardinality_scale (dict or None): + The frequency of each number of repetitions in the data: + { + 'num_repetitions': list of int, + 'frequency': list of float, + } """ self.reset_randomization() self.columns = [column_name] @@ -344,7 +398,14 @@ def _set_fitted_parameters(self, column_name, nan_frequency=0.0, cardinality=Non 'Cardinality "match" rule must specify a cardinality value.' ) + if self.cardinality_rule == 'scale': + if not cardinality_scale: + raise TransformerInputError( + 'Cardinality "scale" rule must specify a cardinality value.' 
+ ) + self._data_cardinality = cardinality + self._data_cardinality_scale = cardinality_scale self._nan_frequency = nan_frequency def __repr__(self): diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index 297c4205..59f84fe3 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -376,7 +376,6 @@ def __getitem__(self, sdtype): def _handle_enforce_uniqueness_and_cardinality_rule(enforce_uniqueness, cardinality_rule): - result = cardinality_rule if enforce_uniqueness is not None: warnings.warn( "The 'enforce_uniqueness' parameter is no longer supported. " @@ -384,9 +383,14 @@ def _handle_enforce_uniqueness_and_cardinality_rule(enforce_uniqueness, cardinal FutureWarning, ) if enforce_uniqueness and cardinality_rule is None: - result = 'unique' + return 'unique' - return result + if cardinality_rule not in ['unique', 'match', 'scale', None]: + raise ValueError( + "The 'cardinality_rule' parameter must be one of 'unique', 'match', 'scale', or None." + ) + + return cardinality_rule def _extract_timezone_from_a_string(dt_str): @@ -466,3 +470,30 @@ def data_has_multiple_timezones(data): except ValueError: return False + + +def _get_cardinality_frequency(data): + """Get number of repetitions of values in the data and their frequencies.""" + value_counts = data.value_counts(dropna=False) + repetition_counts = value_counts.value_counts().sort_index() + total = repetition_counts.sum() + frequencies = (repetition_counts / total).tolist() + repetitions = repetition_counts.index.tolist() + + return repetitions, frequencies + + +def _sample_repetitions(num_samples, value, data_cardinality_scale, remaining_samples): + """Sample a number of repetitions for a given value.""" + repetitions = np.random.choice( + data_cardinality_scale['num_repetitions'], + p=data_cardinality_scale['frequency'], + ) + if repetitions <= num_samples: + samples = [value] * repetitions + else: + samples = [value] * num_samples + remaining_samples['repetitions'] = 
repetitions - num_samples + remaining_samples['value'] = value + + return samples, remaining_samples diff --git a/tests/integration/transformers/pii/test_anonymizer.py b/tests/integration/transformers/pii/test_anonymizer.py index 7a100f99..314c072d 100644 --- a/tests/integration/transformers/pii/test_anonymizer.py +++ b/tests/integration/transformers/pii/test_anonymizer.py @@ -318,6 +318,183 @@ def test_anonymized_faker_produces_only_n_values_for_each_reverse_transform_card # Assert assert set(first_reverse_transformed['name']) == set(second_reverse_transformed['name']) + def test_cardinality_rule_scale(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 100}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + assert set(out['col']) == {'KAab', 'qOSU'} + + value_counts = out['col'].value_counts() + assert value_counts['KAab'] == 50 + assert value_counts['qOSU'] == 100 + + def test_cardinality_rule_scale_nans(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': [np.nan] * 50 + ['B'] * 100}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + value_counts = out['col'].value_counts() + assert value_counts['MGWz'] == 100 + assert out['col'].isna().sum() == 50 + + def test_cardinality_rule_scale_one_value(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + pd.testing.assert_frame_equal(out, pd.DataFrame({'col': ['qOSU'] * 50})) + + def 
test_cardinality_rule_scale_one_value_many_transform(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + instance.fit_transform(data, 'col') + out = instance.reverse_transform(pd.DataFrame(index=range(200))) + + # Assert + value_counts = out['col'].value_counts() + assert value_counts['qOSU'] == 50 + assert value_counts['JEWW'] == 50 + assert value_counts['KAab'] == 50 + assert value_counts['CPmg'] == 50 + + def test_cardinality_rule_scale_empty_data(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': []}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + pd.testing.assert_frame_equal(out, data, check_dtype=False) + + def test_cardinality_rule_scale_proportions(self): + """Test when cardinality rule is 'scale'.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + value_counts = out['col'].value_counts() + one_count = (value_counts == 1).sum() + two_count = (value_counts == 2).sum() + three_count = (value_counts == 3).sum() + more_count = (value_counts > 3).sum() + + assert 900 < one_count < 1100 + assert 400 < two_count < 600 + assert 233 < three_count < 433 + assert len(out) == 3000 + assert more_count == 0 + + def assert_proportions(self, out, samples): + value_counts = out['col'].value_counts() + one_count = (value_counts == 1).sum() + two_count = (value_counts == 2).sum() + three_count = (value_counts == 3).sum() + more_count = (value_counts > 3).sum() + + assert 
np.isclose(one_count, two_count * 2, atol=samples * 0.2) + assert np.isclose(one_count, three_count * 3, atol=samples * 0.2) + assert len(out) == samples + assert more_count <= 1 + + def test_cardinality_rule_scale_called_multiple_times(self): + """Test calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed_data = instance.fit_transform(data, 'col') + first_reverse_transform = instance.reverse_transform(transformed_data.head(500)) + second_reverse_transform = instance.reverse_transform(transformed_data.head(1000)) + third_reverse_transform = instance.reverse_transform(transformed_data.head(2000)) + fourth_reverse_transform = instance.reverse_transform(transformed_data.head(1111)) + + # Assert + self.assert_proportions(first_reverse_transform, 500) + self.assert_proportions(second_reverse_transform, 1000) + self.assert_proportions(third_reverse_transform, 2000) + self.assert_proportions(fourth_reverse_transform, 1111) + self.assert_proportions( + pd.concat([ + first_reverse_transform, + second_reverse_transform, + third_reverse_transform, + fourth_reverse_transform, + ]), + 4611, + ) + + first_set = set(first_reverse_transform['col']) + second_set = set(second_reverse_transform['col']) + third_set = set(third_reverse_transform['col']) + fourth_set = set(fourth_reverse_transform['col']) + + assert len(first_set.intersection(second_set)) <= 1 + assert len(first_set.intersection(third_set)) <= 1 + assert len(first_set.intersection(fourth_set)) <= 1 + assert len(second_set.intersection(third_set)) <= 1 + assert len(second_set.intersection(fourth_set)) <= 1 + assert len(third_set.intersection(fourth_set)) <= 1 + + def test_cardinality_rule_scale_called_multiple_times_remaining_samples(self): + """Test 
calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + hundred = [i // 100 for i in range(1000)] + two_hundred = [i // 200 for i in range(2000, 3000)] + data = pd.DataFrame({'col': hundred + two_hundred}) + instance = AnonymizedFaker(cardinality_rule='scale') + + # Run + transformed_data = instance.fit_transform(data, 'col') + first_out = instance.reverse_transform(transformed_data.head(250)) + remaining_value = instance._remaining_samples['value'] + remaining_samples = instance._remaining_samples['repetitions'] + second_out = instance.reverse_transform(transformed_data) + + # Assert + assert len(first_out) == 250 + assert len(first_out[first_out['col'] == remaining_value]) == 50 + assert len(second_out['col']) == 2_000 + assert len(second_out[second_out['col'] == remaining_value]) == remaining_samples + class TestPsuedoAnonymizedFaker: def test_default_settings(self): diff --git a/tests/integration/transformers/test_id.py b/tests/integration/transformers/test_id.py index 8460ddb9..b4d34ead 100644 --- a/tests/integration/transformers/test_id.py +++ b/tests/integration/transformers/test_id.py @@ -1,4 +1,5 @@ import pickle +import warnings import numpy as np import pandas as pd @@ -335,6 +336,420 @@ def test_cardinality_rule_not_enough_values_numerical(self): expected = pd.DataFrame({'id': ['2', '3', '4', '5', '6']}, dtype=object) pd.testing.assert_frame_equal(reverse_transform, expected) + def test_cardinality_rule_match(self): + """Test with cardinality_rule='match'.""" + # Setup + data = pd.DataFrame({'id': [1, 2, 3, 4, 5]}) + instance = RegexGenerator('[1-3]{1}', cardinality_rule='match') + + # Run + transformed = instance.fit_transform(data, 'id') + reverse_transform = instance.reverse_transform(transformed) + + # Assert + expected = pd.DataFrame({'id': ['1', '2', '3', '4', '5']}, dtype=object) + pd.testing.assert_frame_equal(reverse_transform, expected) + + def test_cardinality_rule_match_not_enough_values(self): + """Test with 
cardinality_rule='match' but insufficient regex values.""" + # Setup + data = pd.DataFrame({'id': [1, 2, 3, 4, 5]}) + instance = RegexGenerator('[1-3]{1}', cardinality_rule='match') + + # Run + transformed = instance.fit_transform(data, 'id') + reverse_transform = instance.reverse_transform(transformed) + + # Assert + expected = pd.DataFrame({'id': ['1', '2', '3', '4', '5']}, dtype=object) + pd.testing.assert_frame_equal(reverse_transform, expected) + + def test_called_multiple_times_cardinality_rule_match(self): + """Test calling multiple times when ``cardinality_rule`` is ``match``.""" + # Setup + data = pd.DataFrame({'my_column': [1, 2, 3, 4, 5] * 3}) + generator = RegexGenerator(cardinality_rule='match') + + # Run + transformed_data = generator.fit_transform(data, 'my_column') + first_reverse_transform = generator.reverse_transform(transformed_data.head(3)) + second_reverse_transform = generator.reverse_transform(transformed_data.head(4)) + third_reverse_transform = generator.reverse_transform(transformed_data.head(5)) + fourth_reverse_transform = generator.reverse_transform(transformed_data.head(11)) + + # Assert + expected_first_reverse_transform = pd.DataFrame({'my_column': ['AAAAA', 'AAAAB', 'AAAAC']}) + expected_second_reverse_transform = pd.DataFrame({ + 'my_column': ['AAAAA', 'AAAAB', 'AAAAC', 'AAAAD'] + }) + expected_third_reverse_transform = pd.DataFrame({ + 'my_column': ['AAAAA', 'AAAAB', 'AAAAC', 'AAAAD', 'AAAAE'] + }) + expected_fourth_reverse_transform = pd.DataFrame({ + 'my_column': [ + 'AAAAA', + 'AAAAB', + 'AAAAC', + 'AAAAD', + 'AAAAE', + 'AAAAA', + 'AAAAB', + 'AAAAC', + 'AAAAD', + 'AAAAE', + 'AAAAA', + ] + }) + pd.testing.assert_frame_equal(first_reverse_transform, expected_first_reverse_transform) + pd.testing.assert_frame_equal(second_reverse_transform, expected_second_reverse_transform) + pd.testing.assert_frame_equal(third_reverse_transform, expected_third_reverse_transform) + pd.testing.assert_frame_equal(fourth_reverse_transform, 
expected_fourth_reverse_transform) + + def test_cardinality_rule_match_empty_regex(self): + """Test with cardinality_rule='match' and an empty regex.""" + # Setup + data = pd.DataFrame({'id': [1, 2, 3, 4, 5]}) + instance_unique = RegexGenerator('', cardinality_rule='unique') + instance_match = RegexGenerator('', cardinality_rule='match') + instance_none = RegexGenerator('') + + # Run + transformed_unique = instance_unique.fit_transform(data, 'id') + transformed_match = instance_match.fit_transform(data, 'id') + transformed_none = instance_none.fit_transform(data, 'id') + reverse_transform_unique = instance_unique.reverse_transform(transformed_unique) + reverse_transform_match = instance_match.reverse_transform(transformed_match) + reverse_transform_none = instance_none.reverse_transform(transformed_none) + + # Assert + expected = pd.DataFrame({'id': ['', '', '', '', '']}) + pd.testing.assert_frame_equal(reverse_transform_unique, expected) + pd.testing.assert_frame_equal(reverse_transform_match, expected) + pd.testing.assert_frame_equal(reverse_transform_none, expected) + + def test_cardinality_rule_scale(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 100}) + instance = RegexGenerator(regex_format='[a-z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + assert len(recorded_warnings) == 0 + + # Assert + assert set(out['col']).issubset({'a', 'b', 'c'}) + + value_counts = out['col'].value_counts() + assert value_counts['a'] in {50, 100, 150} + assert value_counts.get('b', 0) in {0, 50, 100} + assert value_counts.get('c', 0) in {0, 50} + + assert value_counts.sum() == 150 + + def test_cardinality_rule_scale_nans(self): + """Test when cardinality rule is 'scale' and the data contains NaNs.""" + # Setup + data = pd.DataFrame({'col': 
[np.nan] * 50 + ['B'] * 100}) + instance = RegexGenerator(regex_format='[a-z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + assert len(recorded_warnings) == 0 + + # Assert + assert set(out['col']).issubset({'a', 'b', 'c'}) + + value_counts = out['col'].value_counts() + assert value_counts['a'] in {50, 100, 150} + assert value_counts.get('b', 0) in {0, 50, 100} + assert value_counts.get('c', 0) in {0, 50} + + assert value_counts.sum() == 150 + + def test_cardinality_rule_scale_one_value(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + assert len(recorded_warnings) == 0 + + # Assert + pd.testing.assert_frame_equal(out, data) + + def test_cardinality_rule_scale_one_value_many_transform(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + instance.fit_transform(data, 'col') + out = instance.reverse_transform(pd.DataFrame(index=range(200))) + + assert len(recorded_warnings) == 0 + + # Assert + expected = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 50 + ['C'] * 50 + ['D'] * 50}) + pd.testing.assert_frame_equal(out, expected) + + def test_cardinality_rule_scale_empty_data(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': []}) + instance = 
RegexGenerator(regex_format='[a-z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + assert len(recorded_warnings) == 0 + + # Assert + pd.testing.assert_frame_equal(out, data, check_dtype=False) + + def test_cardinality_rule_scale_proportions(self): + """Test when cardinality rule is 'scale'.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(regex_format='[a-z]{3}', cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + value_counts = out['col'].value_counts() + one_count = (value_counts == 1).sum() + two_count = (value_counts == 2).sum() + three_count = (value_counts == 3).sum() + more_count = (value_counts > 3).sum() + + assert 900 < one_count < 1100 + assert 400 < two_count < 600 + assert 233 < three_count < 433 + assert len(out) == 3000 + assert more_count == 0 + + def test_cardinality_rule_scale_not_enough_regex_categorical(self): + """Test when cardinality rule is 'scale'.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(regex_format='[a-z]', cardinality_rule='scale') + + # Run + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + # Assert + value_counts = out['col'].value_counts() + one_count = (value_counts == 1).sum() + two_count = (value_counts == 2).sum() + three_count = (value_counts == 3).sum() + more_count = (value_counts > 3).sum() + + assert 900 < one_count < 1100 + assert 400 < two_count < 
600 + assert 233 < three_count < 433 + assert len(out) == 3000 + assert more_count == 0 + + def assert_proportions(self, out, samples): + value_counts = out['col'].value_counts() + one_count = (value_counts == 1).sum() + two_count = (value_counts == 2).sum() + three_count = (value_counts == 3).sum() + more_count = (value_counts > 3).sum() + + assert np.isclose(one_count, two_count * 2, atol=samples * 0.2) + assert np.isclose(one_count, three_count * 3, atol=samples * 0.2) + assert len(out) == samples + assert more_count <= 1 + + def test_cardinality_rule_scale_not_enough_regex_numerical(self): + """Test when cardinality rule is 'scale'.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(regex_format='[1-3]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed = instance.fit_transform(data, 'col') + out = instance.reverse_transform(transformed) + + assert len(recorded_warnings) == 1 + warn_msg = ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' 
+ ) + assert warn_msg == str(recorded_warnings[0].message) + + # Assert + self.assert_proportions(out, 3000) + + def test_cardinality_rule_scale_called_multiple_times(self): + """Test calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(cardinality_rule='scale', generation_order='alphanumeric') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed_data = instance.fit_transform(data, 'col') + first_reverse_transform = instance.reverse_transform(transformed_data.head(500)) + second_reverse_transform = instance.reverse_transform(transformed_data.head(1000)) + third_reverse_transform = instance.reverse_transform(transformed_data.head(2000)) + fourth_reverse_transform = instance.reverse_transform(transformed_data.head(1111)) + + assert len(recorded_warnings) == 0 + + # Assert + self.assert_proportions(first_reverse_transform, 500) + self.assert_proportions(second_reverse_transform, 1000) + self.assert_proportions(third_reverse_transform, 2000) + self.assert_proportions(fourth_reverse_transform, 1111) + + first_set = set(first_reverse_transform['col']) + second_set = set(second_reverse_transform['col']) + third_set = set(third_reverse_transform['col']) + assert first_set.isdisjoint(set(second_reverse_transform['col'][200:])) + assert first_set.isdisjoint(set(third_reverse_transform['col'][200:])) + assert first_set.isdisjoint(set(fourth_reverse_transform['col'][200:])) + assert second_set.isdisjoint(set(third_reverse_transform['col'][200:])) + assert second_set.isdisjoint(set(fourth_reverse_transform['col'][200:])) + assert third_set.isdisjoint(set(fourth_reverse_transform['col'][200:])) + + def test_cardinality_rule_scale_called_multiple_times_not_enough_regex(self): + """Test 
calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(regex_format='[1-3]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed_data = instance.fit_transform(data, 'col') + first_reverse_transform = instance.reverse_transform(transformed_data.head(500)) + second_reverse_transform = instance.reverse_transform(transformed_data.head(1000)) + third_reverse_transform = instance.reverse_transform(transformed_data.head(2000)) + fourth_reverse_transform = instance.reverse_transform(transformed_data.head(1111)) + + assert len(recorded_warnings) == 4 + warn_msg = ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' 
+ ) + for warning in recorded_warnings: + assert warn_msg == str(warning.message) + + # Assert + self.assert_proportions(first_reverse_transform, 500) + self.assert_proportions(second_reverse_transform, 1000) + self.assert_proportions(third_reverse_transform, 2000) + self.assert_proportions(fourth_reverse_transform, 1111) + + def test_cardinality_rule_scale_called_multiple_times_not_enough_regex_categorical(self): + """Test calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + once = list(range(1000)) + twice = [i // 2 for i in range(2000, 3000)] + thrice = [i // 3 for i in range(4500, 5500)] + data = pd.DataFrame({'col': once + twice + thrice}) + instance = RegexGenerator(regex_format='[a-z]', cardinality_rule='scale') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed_data = instance.fit_transform(data, 'col') + first_reverse_transform = instance.reverse_transform(transformed_data.head(500)) + second_reverse_transform = instance.reverse_transform(transformed_data.head(1000)) + third_reverse_transform = instance.reverse_transform(transformed_data.head(2000)) + fourth_reverse_transform = instance.reverse_transform(transformed_data.head(1111)) + + assert len(recorded_warnings) == 4 + warn_msg = ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' 
+ ) + for warning in recorded_warnings: + assert warn_msg == str(warning.message) + + # Assert + self.assert_proportions(first_reverse_transform, 500) + self.assert_proportions(second_reverse_transform, 1000) + self.assert_proportions(third_reverse_transform, 2000) + self.assert_proportions(fourth_reverse_transform, 1111) + + def test_cardinality_rule_scale_called_multiple_times_remaining_samples(self): + """Test calling multiple times when ``cardinality_rule`` is ``scale``.""" + # Setup + hundred = [i // 100 for i in range(1000)] + two_hundred = [i // 200 for i in range(2000, 3000)] + data = pd.DataFrame({'col': hundred + two_hundred}) + instance = RegexGenerator( + regex_format='[a-f]', cardinality_rule='scale', generation_order='alphanumeric' + ) + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + transformed_data = instance.fit_transform(data, 'col') + first_out = instance.reverse_transform(transformed_data.head(250)) + second_out = instance.reverse_transform(transformed_data.head(3_000)) + + assert len(recorded_warnings) == 1 + warn_msg = ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' 
+ ) + assert warn_msg == str(recorded_warnings[0].message) + + # Assert + assert len(first_out) == 250 + assert len(set(first_out['col'][200:])) == 1 + pd.testing.assert_series_equal( + first_out['col'][200:], + second_out['col'][:50], + check_index=False, + ) + assert second_out['col'][0] not in second_out['col'][50:] + class TestHyperTransformer: def test_end_to_end_scrambled(self): diff --git a/tests/unit/transformers/pii/test_anonymizer.py b/tests/unit/transformers/pii/test_anonymizer.py index c9dde11a..550a3abf 100644 --- a/tests/unit/transformers/pii/test_anonymizer.py +++ b/tests/unit/transformers/pii/test_anonymizer.py @@ -581,21 +581,21 @@ def test__reverse_transform_match_cardinality(self): AnonymizedFaker._reverse_transform(instance, None) # Assert - instance._reverse_transform_cardinality_rule_match.assert_called_once_with(3) + instance._reverse_transform_cardinality_rules.assert_called_once_with(3) - def test__reverse_transform_cardinality_rule_match_only_nans(self): + def test__reverse_transform_cardinality_rules_only_nans(self): """Test it with only nans.""" # Setup instance = AnonymizedFaker() instance._nan_frequency = 1 # Run - result = instance._reverse_transform_cardinality_rule_match(3) + result = instance._reverse_transform_cardinality_rules(3) # Assert assert pd.isna(result).all() - def test__reverse_transform_cardinality_rule_match_no_missing_value(self): + def test__reverse_transform_cardinality_rules_no_missing_value(self): """Test it with default values.""" # Setup instance = AnonymizedFaker(missing_value_generation=None) @@ -604,16 +604,37 @@ def test__reverse_transform_cardinality_rule_match_no_missing_value(self): instance._unique_categories = ['a', 'b', 'c'] function = Mock() function.side_effect = ['a', 'b', 'c'] - + instance.cardinality_rule = 'match' instance._function = function # Run - result = instance._reverse_transform_cardinality_rule_match(3) + result = instance._reverse_transform_cardinality_rules(3) # Assert assert 
set(result) == set(['a', 'b', 'c']) - def test__reverse_transform_cardinality_rule_match_not_enough_unique(self): + def test__reverse_transform_cardinality_rules_scale(self): + """Test it with scale cardinality.""" + # Setup + instance = AnonymizedFaker(missing_value_generation=None) + instance._data_cardinality = 2 + instance._nan_frequency = 0 + instance._data_cardinality_scale = { + 'num_repetitions': [1, 2, 3], + 'frequency': [0.1, 0.2, 0.7], + } + function = Mock() + function.side_effect = ['a', 'b', 'c'] + instance.cardinality_rule = 'scale' + instance._function = function + + # Run + result = instance._reverse_transform_cardinality_rules(3) + + # Assert + assert set(result).issubset(set(['a', 'b', 'c'])) + + def test__reverse_transform_cardinality_rules_not_enough_unique(self): """Test it when there are not enough unique values.""" # Setup instance = AnonymizedFaker() @@ -622,9 +643,10 @@ def test__reverse_transform_cardinality_rule_match_not_enough_unique(self): function = Mock() function.side_effect = ['a', 'b', 'c', 'd'] instance._function = function + instance.cardinality_rule = 'match' # Run - result = instance._reverse_transform_cardinality_rule_match(6) + result = instance._reverse_transform_cardinality_rules(6) # Assert assert set(result) == {'a', 'b', 'c'} @@ -747,28 +769,112 @@ def test__reverse_transform_size_is_length_of_data(self): assert function.call_args_list == [call(), call(), call()] np.testing.assert_array_equal(result, np.array(['a', 'b', 'c'])) + def test__reverse_transform_scale(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 100}) + instance = AnonymizedFaker(cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(data) + + # Assert + assert out[out == 'KAab'].size in {50, 100, 150} + assert out[out == 'qOSU'].size in {0, 50, 100} + + def test__reverse_transform_scale_multiple_calls(self): + """Test when cardinality rule is 
'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 50 + ['C'] * 50}) + instance = AnonymizedFaker(cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out1 = instance._reverse_transform(data) + out2 = instance._reverse_transform(data) + instance.reset_randomization() + out3 = instance._reverse_transform(data) + + # Assert + assert out1[out1 == 'KAab'].size == 50 + assert out1[out1 == 'qOSU'].size == 50 + assert out1[out1 == 'CPmg'].size == 50 + + assert out2[out2 == 'urbw'].size == 50 + assert out2[out2 == 'JEWW'].size == 50 + assert out2[out2 == 'LRyt'].size == 50 + + assert out3[out3 == 'KAab'].size == 50 + assert out3[out3 == 'qOSU'].size == 50 + assert out3[out3 == 'CPmg'].size == 50 + + def test__reverse_transform_scale_remaining_values(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 10 + ['B'] * 3}) + instance = AnonymizedFaker(cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out1 = instance._reverse_transform(data.head(8)) + out2 = instance._reverse_transform(data) + + # Assert + assert out1[out1 == 'qOSU'].size == 3 + assert out1[out1 == 'KAab'].size == 5 + assert out2[out2 == 'KAab'].size == 5 + assert out2[out2 == 'CPmg'].size == 8 + + def test__reverse_transform_scale_many_remaining_values(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 100}) + instance = AnonymizedFaker(cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out1 = instance._reverse_transform(data.head(10)) + out2 = instance._reverse_transform(data.head(10)) + + # Assert + assert np.array_equal(out1, np.array(['qOSU'] * 10)) + assert np.array_equal(out2, np.array(['qOSU'] * 10)) + def test__set_fitted_parameters(self): """Test ``_set_fitted_parameters`` sets the required parameters for transformer.""" # Setup transformer = AnonymizedFaker() - transformer.cardinality_rule = 'match' frequency = 0.30 cardinality = 3 
column_name = 'mock' - error_msg = re.escape('Cardinality "match" rule must specify a cardinality value.') # Run + transformer.cardinality_rule = 'match' + error_msg = re.escape('Cardinality "match" rule must specify a cardinality value.') + with pytest.raises(TransformerInputError, match=error_msg): + transformer._set_fitted_parameters(column_name, nan_frequency=frequency) + + transformer.cardinality_rule = 'scale' + error_msg = re.escape('Cardinality "scale" rule must specify a cardinality value.') with pytest.raises(TransformerInputError, match=error_msg): transformer._set_fitted_parameters(column_name, nan_frequency=frequency) transformer._set_fitted_parameters( - column_name, nan_frequency=frequency, cardinality=cardinality + column_name, + nan_frequency=frequency, + cardinality=cardinality, + cardinality_scale={'num_repetitions': [1, 2, 3], 'frequency': [0.1, 0.2, 0.7]}, ) # Assert assert transformer._nan_frequency == frequency assert transformer._data_cardinality == cardinality assert transformer.columns == [column_name] + assert transformer._data_cardinality_scale == { + 'num_repetitions': [1, 2, 3], + 'frequency': [0.1, 0.2, 0.7], + } def test___repr__default(self): """Test the ``__repr__`` method. 
diff --git a/tests/unit/transformers/test_id.py b/tests/unit/transformers/test_id.py index f8b7f573..0f7ff0e9 100644 --- a/tests/unit/transformers/test_id.py +++ b/tests/unit/transformers/test_id.py @@ -1,6 +1,7 @@ """Test for ID transformers.""" import re +import warnings from string import ascii_uppercase from unittest.mock import Mock, patch @@ -178,6 +179,10 @@ def test___getstate__(self): 'regex_format': '[A-Za-z]{5}', 'random_states': mock_random_sates, 'generation_order': 'alphanumeric', + '_unique_regex_values': None, + '_data_cardinality': None, + '_data_cardinality_scale': None, + '_remaining_samples': {'value': None, 'repetitions': 0}, } @patch('rdt.transformers.id.strings_from_regex') @@ -266,6 +271,21 @@ def test___init__custom(self): assert instance.cardinality_rule == 'unique' assert instance.generation_order == 'scrambled' + def test___init__cardinality_rule_match(self): + """Test it when cardinality_rule is 'match'.""" + # Run + instance = RegexGenerator( + regex_format='[0-9]', + cardinality_rule='match', + ) + + # Assert + assert instance.data_length is None + assert instance.regex_format == '[0-9]' + assert instance.cardinality_rule == 'match' + assert instance._data_cardinality is None + assert instance._unique_regex_values is None + def test___init__bad_value_generation_order(self): """Test that an error is raised if a bad value is given for `generation_order`.""" # Run and Assert @@ -347,6 +367,194 @@ def test__fit(self): assert instance.data_length == 3 assert instance.output_properties == {None: {'next_transformer': None}} + def test__fit_cardinality_rule_match(self): + """Test the fit method when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match') + columns_data = pd.Series(['1', '2', '3', '2', '1']) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 5 + assert instance._data_cardinality == 3 + assert instance.output_properties == {None: {'next_transformer': 
None}} + + def test__reverse_transform_cardinality_rule_match(self): + """Test the reverse transform method when cardinality_rule is 'match'.""" + # Setup + instance = RegexGenerator(cardinality_rule='match') + columns_data = pd.DataFrame({'col': ['1', '2', '3', '2', '1']}) + + # Run + instance.fit(columns_data, 'col') + instance._reverse_transform(columns_data) + + # Assert + assert instance._unique_regex_values == ['AAAAA', 'AAAAB', 'AAAAC'] + + def test__fit_cardinality_rule_match_with_regex_format(self): + """Test the fit method when cardinality_rule is 'match' and regex_format is provided.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-5]{1}') + columns_data = pd.Series(['1', '2', '3', '2', '1']) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 5 + assert instance._data_cardinality == 3 + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__reverse_transform_cardinality_rule_match_with_regex_format(self): + """Test reverse transform when cardinality_rule is 'match' and regex_format is provided.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-5]{1}') + columns_data = pd.DataFrame({'col': ['1', '2', '3', '2', '1']}) + + # Run + instance.fit(columns_data, 'col') + instance._reverse_transform(columns_data) + + # Assert + assert instance._unique_regex_values == ['1', '2', '3'] + + def test__fit_cardinality_rule_match_with_nans(self): + """Test the fit method when cardinality_rule is 'match' and there are nans.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-5]{1}') + columns_data = pd.Series(['1', '2', '3', '2', '1', np.nan, None, np.nan]) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 8 + assert instance._data_cardinality == 4 + assert instance.output_properties == {None: {'next_transformer': None}} + + def 
test__reverse_transform_cardinality_rule_match_with_nans(self): + """Test the reverse transform method when cardinality_rule is 'match' and there are nans.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-5]{1}') + columns_data = pd.DataFrame({'col': ['1', '2', '3', '2', '1', np.nan, None, np.nan]}) + + # Run + instance.fit(columns_data, 'col') + instance._reverse_transform(columns_data) + + # Assert + assert instance._unique_regex_values == ['1', '2', '3', '4'] + + def test__fit_cardinality_rule_match_with_nans_too_many_values(self): + """Test the fit method when cardinality_rule is 'match' and values don't follow regex.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-3]{1}') + columns_data = pd.Series(['1', '2', '3', '2', '4']) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 5 + assert instance._data_cardinality == 4 + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__reverse_transform_cardinality_rule_match_with_nans_too_many_values(self): + """Test reverse transform when cardinality_rule is 'match' and values don't follow regex.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[1-3]{1}') + columns_data = pd.DataFrame({'col': ['1', '2', '3', '2', '4']}) + + # Run + instance.fit(columns_data, 'col') + instance._reverse_transform(columns_data) + + # Assert + assert instance._unique_regex_values == ['1', '2', '3', '4'] + + def test__fit_cardinality_rule_match_with_too_many_values_str(self): + """Test fit when cardinality_rule is 'match' and string values don't follow regex.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[a-b]{1}') + columns_data = pd.Series(['a', 'b', 'c', 'b', 'd', 'f']) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 6 + assert instance._data_cardinality == 5 + assert instance.output_properties == 
{None: {'next_transformer': None}} + + def test__reverse_transform_cardinality_rule_match_with_too_many_values_str(self): + """Test reverse transform when cardinality_rule='match' and strings don't follow regex.""" + # Setup + instance = RegexGenerator(cardinality_rule='match', regex_format='[a-b]{1}') + columns_data = pd.DataFrame({'col': ['a', 'b', 'c', 'b', 'd', 'f']}) + + # Run + instance.fit(columns_data, 'col') + instance._reverse_transform(columns_data) + + # Assert + assert instance._unique_regex_values == ['a', 'b', 'a(0)', 'b(0)', 'a(1)'] + assert len(instance._unique_regex_values) == 5 + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__fit_cardinality_rule_scale(self): + """Test fit when cardinality_rule is 'scale'.""" + # Setup + instance = RegexGenerator(cardinality_rule='scale') + columns_data = pd.Series(['1', '2', '3', '4', '5', '5', '6', '6', '7', '7', '7', '7']) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 12 + assert instance._data_cardinality_scale == { + 'num_repetitions': [1, 2, 4], + 'frequency': [4 / 7, 2 / 7, 1 / 7], + } + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__fit_cardinality_rule_scale_nans(self): + """Test fit when cardinality_rule is 'scale' and there are nans.""" + # Setup + instance = RegexGenerator(cardinality_rule='scale') + columns_data = pd.Series([np.nan, np.nan, None, None, '1', 2]) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 6 + assert instance._data_cardinality_scale == { + 'num_repetitions': [1, 4], + 'frequency': [2 / 3, 1 / 3], + } + assert instance.output_properties == {None: {'next_transformer': None}} + + def test__fit_cardinality_rule_scale_only_nans(self): + """Test fit when cardinality_rule is 'scale' and there are only nans.""" + # Setup + instance = RegexGenerator(cardinality_rule='scale') + columns_data = pd.Series([np.nan, 
np.nan, None, None, float('nan'), float('nan')]) + + # Run + instance._fit(columns_data) + + # Assert + assert instance.data_length == 6 + assert instance._data_cardinality_scale == { + 'num_repetitions': [6], + 'frequency': [1], + } + assert instance.output_properties == {None: {'next_transformer': None}} + def test__transform(self): """Test the ``_transform`` method. @@ -386,6 +594,7 @@ def test__reverse_transform_generation_order_scrambled(self, shuffle_mock): instance.generator_size = 5 instance.generated = 0 instance.generation_order = 'scrambled' + instance.columns = ['col'] # Run result = instance._reverse_transform(columns_data) @@ -417,6 +626,7 @@ def test__reverse_transform_generator_size_bigger_than_data_length(self): instance.generator = generator instance.generator_size = 5 instance.generated = 0 + instance.columns = ['col'] # Run result = instance._reverse_transform(columns_data) @@ -625,3 +835,240 @@ def test__reverse_transform_info_message(self, mock_logger): expected_args = (6, 'a', 5, 'a') mock_logger.info.assert_called_once_with(expected_format, *expected_args) + + def test__reverse_transform_match_not_enough_values(self): + """Test the case when there are not enough values to match the cardinality rule.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'C', 'D', 'E']}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run and Assert + warn_msg = re.escape( + 'Only 3 values can be generated. Cannot match the cardinality ' + 'of the data, it requires 5 values.' 
+ ) + with pytest.warns(UserWarning, match=warn_msg): + out = instance._reverse_transform(data[:3]) + + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C'])) + + def test__reverse_transform_match_too_many_samples(self): + """Test it when the number of samples is bigger than the generator size.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'C', 'D', 'E']}) + instance = RegexGenerator(regex_format='[A-C]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run and Assert + warn_msg = re.escape( + "The regex for 'col' can only generate 3 unique values. Additional values may not " + 'exactly follow the provided regex.' + ) + with pytest.warns(UserWarning, match=warn_msg): + out = instance._reverse_transform(data) + + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C', 'A(0)', 'B(0)'])) + + def test__reverse_transform_only_one_warning(self): + """Test it when the num_samples < generator_size but data_cardinality > generator_size.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'C', 'D', 'E']}) + instance = RegexGenerator(regex_format='[A-C]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run and Assert + warn_msg = re.escape( + 'Only 2 values can be generated. Cannot match the cardinality of the data, ' + 'it requires 5 values.' + ) + with pytest.warns(UserWarning, match=warn_msg): + out = instance._reverse_transform(data[:2]) + + np.testing.assert_array_equal(out, np.array(['A', 'B'])) + instance._unique_regex_values = ['A', 'B', 'C', 'A(0)', 'B(0)'] + + def test__reverse_transform_two_warnings(self): + """Test it when data_cardinality > num_samples > generator_size.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'C', 'D', 'E']}) + instance = RegexGenerator(regex_format='[A-C]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run and Assert + warn_msg_1 = re.escape( + 'Only 4 values can be generated. Cannot match the cardinality of the data, ' + 'it requires 5 values.' 
+ ) + warn_msg_2 = re.escape( + "The regex for 'col' can only generate 3 unique values. Additional values may " + 'not exactly follow the provided regex.' + ) + with pytest.warns(UserWarning, match=warn_msg_1): + with pytest.warns(UserWarning, match=warn_msg_2): + out = instance._reverse_transform(data[:4]) + + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C', 'A(0)'])) + + def test__reverse_transform_match_empty_data(self): + """Test it when the data is empty and the cardinality rule is 'match'.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'C', 'D', 'E']}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(pd.Series()) + + # Assert + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C', 'D', 'E'])) + + def test__reverse_transform_match_with_nans(self): + """Test it when the data has nans and the cardinality rule is 'match'.""" + # Setup + data = pd.DataFrame({'col': ['A', np.nan, np.nan, 'D', np.nan]}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(data) + + # Assert + np.testing.assert_array_equal(out, np.array(['A', 'B', 'C', 'A', 'B'])) + + def test__reverse_transform_match_too_many_values(self): + """Test it when the data has more values than the cardinality rule.""" + # Setup + data = pd.DataFrame({'col': ['A', 'B', 'B', 'C']}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='match') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(pd.Series([1] * 10)) + + # Assert + np.testing.assert_array_equal( + out, np.array(['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A']) + ) + + def test__reverse_transform_scale(self): + """Test when cardinality rule is 'scale'.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 100}) + instance = RegexGenerator(regex_format='[A-Z]', 
cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(data) + + # Assert + assert set(out).issubset({'A', 'B', 'C'}) + + value_counts = pd.Series(out).value_counts() + assert value_counts['A'] in {50, 100, 150} + assert value_counts.get('B', 0) in {0, 50, 100} + assert value_counts.get('C', 0) in {0, 50} + + assert value_counts.sum() == 150 + + def test__reverse_transform_scale_empty_data(self): + """Test when cardinality rule is 'scale' and the data is empty.""" + # Setup + data = pd.DataFrame({'col': []}) + instance = RegexGenerator(regex_format='[A-Z]', cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out = instance._reverse_transform(data) + + # Assert + assert isinstance(out, np.ndarray) + assert out.size == 0 + + def test__reverse_transform_scale_not_enough_regex(self): + """Test when cardinality rule is 'scale' and the regex cannot generate enough samples.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 50 + ['C'] * 50}) + instance = RegexGenerator(regex_format='[A-B]', cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + out = instance._reverse_transform(data) + + assert len(recorded_warnings) == 1 + assert str(recorded_warnings[0].message) == ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' 
+ ) + + # Assert + assert set(out).issubset({'A', 'B', 'A(0)'}) + + value_counts = pd.Series(out).value_counts() + assert value_counts['A'] in {50, 100, 150} + assert value_counts.get('B', 0) in {0, 50, 100} + assert value_counts.get('A(0)', 0) in {0, 50} + + assert value_counts.sum() == 150 + + def test__reverse_transform_scale_not_enough_regex_multiple_calls(self): + """Test with cardinality_rule='scale' and multiple calls with not enough regex.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 50 + ['B'] * 50 + ['C'] * 50}) + instance = RegexGenerator(regex_format='[A-B]', cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter('always') + out1 = instance._reverse_transform(data) + out2 = instance._reverse_transform(data) + + assert len(recorded_warnings) == 2 + assert str(recorded_warnings[0].message) == ( + "The regex for 'col' cannot generate enough samples. Additional values " + 'may not exactly follow the provided regex.' 
+ ) + + # Assert + assert set(out1) == {'A', 'B', 'A(0)'} + assert set(out2) == {'A', 'B', 'A(0)'} + + for out in [out1, out2]: + value_counts = pd.Series(out).value_counts() + assert value_counts['A'] in {50, 100, 150} + assert value_counts.get('B', 0) in {0, 50, 100} + assert value_counts.get('A(0)', 0) in {0, 50} + assert value_counts.sum() == 150 + + def test__reverse_transform_scale_remaining_values(self): + """Test with cardinality_rule='scale' and a call has remaining values for next call.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 10 + ['B'] * 3}) + instance = RegexGenerator(regex_format='[A-B]', cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out1 = instance._reverse_transform(data.head(8)) + out2 = instance._reverse_transform(data) + + # Assert + assert set(out1).issubset({'A', 'B', 'A(0)'}) + assert set(out2).issubset({'A', 'B', 'A(0)', 'B(0)', 'B(1)', 'B(2)', 'B(3)'}) + + def test__reverse_transform_scale_many_remaining_values(self): + """Test with cardinality_rule='scale' and a call has many remaining values for next call.""" + # Setup + data = pd.DataFrame({'col': ['A'] * 100}) + instance = RegexGenerator(regex_format='[A-B]', cardinality_rule='scale') + instance.fit(data, 'col') + + # Run + out1 = instance._reverse_transform(data.head(10)) + out2 = instance._reverse_transform(data.head(10)) + + # Assert + assert np.array_equal(out1, np.array(['A'] * 10)) + assert np.array_equal(out2, np.array(['A'] * 10)) diff --git a/tests/unit/transformers/test_utils.py b/tests/unit/transformers/test_utils.py index 2bb851aa..f430828c 100644 --- a/tests/unit/transformers/test_utils.py +++ b/tests/unit/transformers/test_utils.py @@ -479,26 +479,24 @@ def test_warn_dict_get(): def test__handle_enforce_uniqueness_and_cardinality_rule(): """Test that ``_handle_enforce_uniqueness_and_cardinality_rule`` works as expected.""" - # Setup - enforce_uniqueness = None - cardinality_rule = None + # Run and Assert + assert 
_handle_enforce_uniqueness_and_cardinality_rule(None, None) is None + expected_message = re.escape( "The 'enforce_uniqueness' parameter is no longer supported. " "Please use the 'cardinality_rule' parameter instead." ) - - # Run - result_1 = _handle_enforce_uniqueness_and_cardinality_rule(enforce_uniqueness, cardinality_rule) with pytest.warns(FutureWarning, match=expected_message): - result_2 = _handle_enforce_uniqueness_and_cardinality_rule(True, None) + assert _handle_enforce_uniqueness_and_cardinality_rule(True, None) == 'unique' - with pytest.warns(FutureWarning, match=expected_message): - result_3 = _handle_enforce_uniqueness_and_cardinality_rule(True, 'other') + err_msg = "The 'cardinality_rule' parameter must be one of 'unique', 'match', 'scale', or None." + with pytest.raises(ValueError, match=err_msg): + _handle_enforce_uniqueness_and_cardinality_rule(None, 'invalid') - # Assert - assert result_1 is None - assert result_2 == 'unique' - assert result_3 == 'other' + assert _handle_enforce_uniqueness_and_cardinality_rule(None, 'unique') == 'unique' + assert _handle_enforce_uniqueness_and_cardinality_rule(None, 'match') == 'match' + assert _handle_enforce_uniqueness_and_cardinality_rule(None, 'scale') == 'scale' + assert _handle_enforce_uniqueness_and_cardinality_rule(None, None) is None def test__extract_timezone_from_a_string_with_valid_timezone():