diff --git a/sdv/cag/_utils.py b/sdv/cag/_utils.py index 72cedaedb..9bf4adbd0 100644 --- a/sdv/cag/_utils.py +++ b/sdv/cag/_utils.py @@ -109,7 +109,7 @@ def _get_is_valid_dict(data, table_name): return { table: pd.Series(True, index=table_data.index) for table, table_data in data.items() - if table != table_name + if table != table_name or table_name is None } diff --git a/sdv/cag/base.py b/sdv/cag/base.py index 9e182fc3b..83d446013 100644 --- a/sdv/cag/base.py +++ b/sdv/cag/base.py @@ -253,10 +253,10 @@ def reverse_transform(self, data): return reverse_transformed - def _is_valid(self, data): + def _is_valid(self, data, metadata): raise NotImplementedError - def is_valid(self, data): + def is_valid(self, data, metadata=None): """Say whether the given table rows are valid. Args: @@ -267,14 +267,21 @@ def is_valid(self, data): pd.Series or dict[pd.Series]: Series of boolean values indicating if the row is valid for the constraint or not. """ - if not self._fitted: + if not self._fitted and metadata is None: raise NotFittedError( - 'Constraint must be fit using ``fit`` before determining if data is valid.' + 'Constraint must be fit using ``fit`` before determining if data is valid ' + 'without providing metadata.' ) - data = self._convert_data_to_dictionary(data, self.metadata) - is_valid_data = self._is_valid(data) - if self._single_table: - return is_valid_data[self._table_name] + metadata = self.metadata if metadata is None else metadata + data_dict = self._convert_data_to_dictionary(data, metadata) + is_valid_data = self._is_valid(data_dict, metadata) + if isinstance(data, pd.DataFrame) or self._single_table: + table_name = ( + self._get_single_table_name(metadata) + if getattr(self, '_table_name', None) is None + else self._table_name + ) + return is_valid_data[table_name] return is_valid_data diff --git a/sdv/cag/fixed_combinations.py b/sdv/cag/fixed_combinations.py index b58ab8ef8..84780f015 100644 --- a/sdv/cag/fixed_combinations.py +++ b/sdv/cag/fixed_combinations.py @@ -182,8 +182,11 @@ def _reverse_transform(self, data): return data - def _is_valid(self, data): + def _is_valid(self, data, metadata): """Determine whether the data matches the constraint.""" + if not self._fitted: + return _get_is_valid_dict(data, table_name=None) + table_name = self._get_single_table_name(self.metadata) is_valid = _get_is_valid_dict(data, table_name) merged = data[table_name].merge( diff --git a/sdv/cag/fixed_increments.py b/sdv/cag/fixed_increments.py index c251e90e8..6197644c5 100644 --- a/sdv/cag/fixed_increments.py +++ b/sdv/cag/fixed_increments.py @@ -152,7 +152,7 @@ def _fit(self, data, metadata): table_name = self._get_single_table_name(metadata) self._dtype = data[table_name][self.column_name].dtype - def _is_valid(self, data): + def _is_valid(self, data, metadata): """Determine if the data is evenly divisible by the increment. Args: @@ -169,7 +169,7 @@ def _is_valid(self, data): number of tables in the data and contain the same table names. """ - table_name = self._get_single_table_name(self.metadata) + table_name = self._get_single_table_name(metadata) is_valid = _get_is_valid_dict(data, table_name) valid = self._check_if_divisible(data, table_name, self.column_name, self.increment_value) is_valid[table_name] = valid diff --git a/sdv/cag/inequality.py b/sdv/cag/inequality.py index 5b3bdd9c1..b73a96835 100644 --- a/sdv/cag/inequality.py +++ b/sdv/cag/inequality.py @@ -119,16 +119,10 @@ def _get_is_datetime(self, metadata, table_name): def _get_datetime_format(self, metadata, table_name, column_name): return metadata.tables[table_name].columns[column_name].get('datetime_format') - def _validate_constraint_with_data(self, data, metadata): - """Validate the data is compatible with the constraint. - - Validate that the inequality requirement is met between the high and low columns. - """ - table_name = self._get_single_table_name(metadata) - data = data[table_name] - low, high = self._get_data(data) + def _get_valid_table_data(self, table_data, metadata, table_name): + low, high = self._get_data(table_data) is_datetime = self._get_is_datetime(metadata, table_name) - if is_datetime and is_object_dtype(data[self._low_column_name]): + if is_datetime and is_object_dtype(table_data[self._low_column_name]): low_format = self._get_datetime_format(metadata, table_name, self._low_column_name) high_format = self._get_datetime_format(metadata, table_name, self._high_column_name) low = cast_to_datetime64(low, low_format) @@ -143,7 +137,15 @@ def _validate_constraint_with_data(self, data, metadata): high_datetime_format=high_format, ) - valid = pd.isna(low) | pd.isna(high) | self._operator(high, low) + return pd.isna(low) | pd.isna(high) | self._operator(high, low) + + def _validate_constraint_with_data(self, data, metadata): + """Validate the data is compatible with the constraint. + + Validate that the inequality requirement is met between the high and low columns. + """ + table_name = self._get_single_table_name(metadata) + valid = self._get_valid_table_data(data[table_name], metadata, table_name) if not valid.all(): invalid_rows = _get_invalid_rows(valid) raise ConstraintNotMetError( @@ -293,7 +295,7 @@ def _reverse_transform(self, data): return data - def _is_valid(self, data): + def _is_valid(self, data, metadata): """Check whether `high` is greater than `low` in each row. Args: @@ -304,24 +306,9 @@ def _is_valid(self, data): dict[str, pd.Series]: Whether each row is valid. """ - table_name = self._get_single_table_name(self.metadata) + table_name = self._get_single_table_name(metadata) is_valid = _get_is_valid_dict(data, table_name) - table_data = data[table_name] - low, high = self._get_data(table_data) - if self._is_datetime and is_object_dtype(self._dtype): - low = cast_to_datetime64(low, self._low_datetime_format) - high = cast_to_datetime64(high, self._high_datetime_format) - - format_matches = bool(self._low_datetime_format == self._high_datetime_format) - if not format_matches: - low, high = match_datetime_precision( - low=low, - high=high, - low_datetime_format=self._low_datetime_format, - high_datetime_format=self._high_datetime_format, - ) - - valid = pd.isna(low) | pd.isna(high) | self._operator(high, low) - is_valid[table_name] = valid + valid_table_rows = self._get_valid_table_data(data[table_name], metadata, table_name) + is_valid[table_name] = valid_table_rows return is_valid diff --git a/sdv/cag/one_hot_encoding.py b/sdv/cag/one_hot_encoding.py index a376bc5c8..026e6a644 100644 --- a/sdv/cag/one_hot_encoding.py +++ b/sdv/cag/one_hot_encoding.py @@ -127,7 +127,7 @@ def _reverse_transform(self, data): return data - def _is_valid(self, data): + def _is_valid(self, data, metadata): """Check whether the data satisfies the one-hot constraint. Args: @@ -138,7 +138,7 @@ def _is_valid(self, data): dict[str, pd.Series]: Whether each row is valid. """ - table_name = self._get_single_table_name(self.metadata) + table_name = self._get_single_table_name(metadata) is_valid = _get_is_valid_dict(data, table_name) is_valid[table_name] = self._get_valid_table_data(data[table_name]) diff --git a/sdv/cag/programmable_constraint.py b/sdv/cag/programmable_constraint.py index 343bb51cc..470efae3a 100644 --- a/sdv/cag/programmable_constraint.py +++ b/sdv/cag/programmable_constraint.py @@ -187,13 +187,14 @@ def _reverse_transform(self, data): return reverse_transformed - def _is_valid(self, data): + def _is_valid(self, data, metadata): if self._is_single_table: - data = data[self._table_name] + table_name = self._get_single_table_name(metadata) + data = data[table_name] is_valid = self.programmable_constraint.is_valid(data) if self._is_single_table: - return {self._table_name: is_valid} + return {table_name: is_valid} return is_valid diff --git a/sdv/cag/range.py b/sdv/cag/range.py index 7383696e6..e2e4cfc0f 100644 --- a/sdv/cag/range.py +++ b/sdv/cag/range.py @@ -354,7 +354,7 @@ def _reverse_transform(self, data): return data - def _is_valid(self, data): + def _is_valid(self, data, metadata): """Check whether the `middle` column is between the `low` and `high` columns. Args: @@ -365,7 +365,7 @@ def _is_valid(self, data): dict[str, pd.Series]: Whether each row is valid. """ - table_name = self._get_single_table_name(self.metadata) + table_name = self._get_single_table_name(metadata) is_valid = _get_is_valid_dict(data, table_name) is_valid[table_name] = self._get_valid_table_data(data[table_name]) diff --git a/tests/integration/cag/test_programmable_constraint.py b/tests/integration/cag/test_programmable_constraint.py index f5109a0d4..4929e30c5 100644 --- a/tests/integration/cag/test_programmable_constraint.py +++ b/tests/integration/cag/test_programmable_constraint.py @@ -99,6 +99,7 @@ def __init__(self, column_names, table_name): self.table_name = table_name self._joint_column = '#'.join(self.column_names) self._combinations = None + self._fitted = False def _get_single_table_name(self, metadata): # Have to define this so that we can re-use existing methods on the constraint @@ -114,6 +115,7 @@ def fit(self, data, metadata): self.metadata = metadata data = {self.table_name: data} FixedCombinations._fit(self, data, metadata) + self._fitted = True def transform(self, data): data = {self.table_name: data} @@ -130,7 +132,7 @@ def reverse_transform(self, transformed_data): def is_valid(self, synthetic_data): synthetic_data = {self.table_name: synthetic_data} - is_valid = FixedCombinations._is_valid(self, synthetic_data) + is_valid = FixedCombinations._is_valid(self, synthetic_data, self.metadata) return is_valid[self.table_name] return MyConstraint diff --git a/tests/unit/cag/test_base.py b/tests/unit/cag/test_base.py index efce7a00c..3b7acbe75 100644 --- a/tests/unit/cag/test_base.py +++ b/tests/unit/cag/test_base.py @@ -560,7 +560,8 @@ def test_is_valid_errors_if_not_fitted(self, data): # Setup instance = BaseConstraint() expected_msg = re.escape( - 'Constraint must be fit using ``fit`` before determining if data is valid.' + 'Constraint must be fit using ``fit`` before determining ' + 'if data is valid without providing metadata.' ) # Run and assert @@ -573,12 +574,13 @@ def test_is_valid(self, data): instance = BaseConstraint() instance._is_valid = Mock() instance._fitted = True + instance.metadata = Mock() # Run is_valid_result = instance.is_valid(data) # Assert - instance._is_valid.assert_called_once_with(data) + instance._is_valid.assert_called_once_with(data, instance.metadata) assert is_valid_result == instance._is_valid.return_value def test_is_valid_single_table(self, data): @@ -591,10 +593,13 @@ def test_is_valid_single_table(self, data): instance._is_valid = Mock() instance._is_valid.return_value = {'table1': data.copy()} instance._fitted = True + instance.metadata = Mock() # Run is_valid_result = instance.is_valid(data) # Assert - instance._is_valid.assert_called_once_with(DataFrameDictMatcher({'table1': data})) + instance._is_valid.assert_called_once_with( + DataFrameDictMatcher({'table1': data}), instance.metadata + ) pd.testing.assert_frame_equal(is_valid_result, data) diff --git a/tests/unit/cag/test_fixed_combinations.py b/tests/unit/cag/test_fixed_combinations.py index 3ca9b9f34..799fa9884 100644 --- a/tests/unit/cag/test_fixed_combinations.py +++ b/tests/unit/cag/test_fixed_combinations.py @@ -581,3 +581,34 @@ def test__is_valid_with_nans(self): expected_invalid_out = pd.Series([False] * 3, name='b#c') pd.testing.assert_series_equal(expected_invalid_out, invalid_out) + + def test__is_valid_unfit(self): + """Test the ``_is_valid`` method when the constraint has not been fit.""" + # Setup + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'a': {'sdtype': 'categorical'}, + 'b': {'sdtype': 'categorical'}, + 'c': {'sdtype': 'categorical'}, + 'd': {'sdtype': 'categorical'}, + } + } + } + }) + data = pd.DataFrame({ + 'a': ['a', 'b', 'c'], + 'b': ['d', 'e', 'f'], + 'c': ['g', 'h', 'i'], + }) + + columns = ['b', 'c'] + instance = FixedCombinations(column_names=columns) + + # Run + valid_out = instance.is_valid(data, metadata) + + # Assert + expected_valid_out = pd.Series([True, True, True]) + pd.testing.assert_series_equal(expected_valid_out, valid_out) diff --git a/tests/unit/cag/test_inequality.py b/tests/unit/cag/test_inequality.py index d6c3ebf63..a96fc6039 100644 --- a/tests/unit/cag/test_inequality.py +++ b/tests/unit/cag/test_inequality.py @@ -889,16 +889,35 @@ def test__is_valid(self): 'c': [7, 8, 9, 10, 11, 12, 13, 14], }) } + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'a': {'sdtype': 'numerical'}, + 'b': {'sdtype': 'numerical'}, + 'c': {'sdtype': 'numerical'}, + } + } + } + }) instance = Inequality(low_column_name='a', high_column_name='b', table_name='table') instance._fitted = True + instance.metadata = metadata + + unfit_instance = Inequality(low_column_name='a', high_column_name='b', table_name='table') # Run - out = instance.is_valid(table_data) + out_fit = instance.is_valid(table_data) + out_unfit = unfit_instance.is_valid(table_data, metadata) # Assert - out = out['table'] + out_fit = out_fit['table'] expected_out = [True, True, False, True, True, False, True, True] - np.testing.assert_array_equal(expected_out, out) + np.testing.assert_array_equal(expected_out, out_fit) + + out_unfit = out_unfit['table'] + expected_out = [True, True, False, True, True, False, True, True] + np.testing.assert_array_equal(expected_out, out_unfit) def test_is_valid_strict_boundaries_true(self): """Test it checks if the data is valid when strict boundaries are True.""" @@ -910,6 +929,17 @@ def test_is_valid_strict_boundaries_true(self): 'c': [7, 8, 9, 10, 11, 12, 13, 14], }) } + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'a': {'sdtype': 'numerical'}, + 'b': {'sdtype': 'numerical'}, + 'c': {'sdtype': 'numerical'}, + } + } + } + }) instance = Inequality( low_column_name='a', high_column_name='b', @@ -917,6 +947,7 @@ def test_is_valid_strict_boundaries_true(self): table_name='table', ) instance._fitted = True + instance.metadata = metadata # Run out = instance.is_valid(table_data) @@ -936,8 +967,20 @@ def test_is_valid_datetimes(self): 'c': [7, 8, 9], }) } + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'a': {'sdtype': 'datetime'}, + 'b': {'sdtype': 'datetime'}, + 'c': {'sdtype': 'numerical'}, + } + } + } + }) instance = Inequality(low_column_name='a', high_column_name='b', table_name='table') instance._fitted = True + instance.metadata = metadata # Run out = instance.is_valid(table_data) @@ -957,9 +1000,20 @@ def test_is_valid_datetime_objects(self): 'c': [7, 8, 9], }) } + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'a': {'sdtype': 'datetime'}, + 'b': {'sdtype': 'datetime'}, + 'c': {'sdtype': 'numerical'}, + } + } + } + }) instance = Inequality(low_column_name='a', high_column_name='b', table_name='table') + instance.metadata = metadata instance._is_datetime = True - instance._dtype = 'O' instance._fitted = True # Run @@ -987,11 +1041,26 @@ def test_is_valid_datetimes_mismatching_datetime_formats(self, mock_match_dateti 'RANDOM_VALUE': [7, 8, 9, 10, 11], }) } + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'SUBMISSION_TIMESTAMP': { + 'sdtype': 'datetime', + 'datetime_format': '%Y-%m-%d %H:%M:%S', + }, + 'DUE_DATE': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'RANDOM_VALUE': {'sdtype': 'numerical'}, + } + } + } + }) instance = Inequality( low_column_name='SUBMISSION_TIMESTAMP', high_column_name='DUE_DATE', table_name='table', ) + instance.metadata = metadata low_return = np.array([ datetime(2020, 5, 18), datetime(2020, 9, 2), @@ -1008,8 +1077,6 @@ def test_is_valid_datetimes_mismatching_datetime_formats(self, mock_match_dateti ]) instance._dtype = 'O' instance._is_datetime = True - instance._low_datetime_format = '%Y-%m-%d %H:%M:%S' - instance._high_datetime_format = '%Y-%m-%d' mock_match_datetime_precision.return_value = (low_return, high_return) instance._fitted = True diff --git a/tests/unit/cag/test_programmable_constraint.py b/tests/unit/cag/test_programmable_constraint.py index 19847c25c..5e9f9e34a 100644 --- a/tests/unit/cag/test_programmable_constraint.py +++ b/tests/unit/cag/test_programmable_constraint.py @@ -294,9 +294,10 @@ def test__is_valid(self): programmable_constraint = ProgrammableConstraint() programmable_constraint.is_valid = Mock() instance = ProgrammableConstraintHarness(programmable_constraint) + metadata = Mock() # Run - instance._is_valid(data) + instance._is_valid(data, metadata) # Assert programmable_constraint.is_valid.assert_called_once_with(data) @@ -310,9 +311,12 @@ def test__is_valid_single_table(self): programmable_constraint.is_valid.return_value = pd.Series([True] * 5) instance = ProgrammableConstraintHarness(programmable_constraint) instance._table_name = 'table' + metadata = Metadata.load_from_dict({ + 'tables': {'table': {'columns': {'col_A': {'sdtype': 'numerical'}}}} + }) # Run - is_valid = instance._is_valid(data) + is_valid = instance._is_valid(data, metadata) # Assert programmable_constraint.is_valid.assert_called_once_with(DataFrameMatcher(data['table']))