Skip to content

Improve CAG validation error message #2587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sdv/cag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import pandas as pd

from sdv._utils import _format_invalid_values_string
from sdv.cag._errors import ConstraintNotMetError
from sdv.data_processing.datetime_formatter import DatetimeFormatter
from sdv.data_processing.numerical_formatter import NumericalFormatter
from sdv.errors import NotFittedError
Expand Down Expand Up @@ -51,6 +53,14 @@ def _get_single_table_name(self, metadata):

return metadata._get_single_table_name() if self.table_name is None else self.table_name

def _format_error_message_constraint(self, invalid_data, table_name):
    """Raise a ``ConstraintNotMetError`` describing the data that breaks this constraint.

    Despite its name, this helper does not return the message: it builds it
    and always raises.

    Args:
        invalid_data:
            The invalid data, passed to ``_format_invalid_values_string``
            with ``5`` as the second argument.
        table_name:
            Name of the table, quoted in the error message.

    Raises:
        ConstraintNotMetError:
            Always, with the constraint class name, the table name and the
            formatted invalid values in the message.
    """
    formatted_rows = _format_invalid_values_string(invalid_data, 5)
    constraint_name = self.__class__.__name__
    message = (
        f"Data is not valid for the '{constraint_name}' constraint in "
        f"table '{table_name}':\n{formatted_rows}"
    )
    raise ConstraintNotMetError(message)

def _validate_constraint_with_metadata(self, metadata):
raise NotImplementedError()

Expand Down
10 changes: 3 additions & 7 deletions sdv/cag/fixed_increments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from sdv._utils import _create_unique_name
from sdv.cag._errors import ConstraintNotMetError
from sdv.cag._utils import (
_get_invalid_rows,
_get_is_valid_dict,
_remove_columns_from_metadata,
_validate_table_and_column_names,
Expand Down Expand Up @@ -112,12 +111,9 @@ def _validate_constraint_with_data(self, data, metadata):
data, self._get_single_table_name(metadata), self.column_name, self.increment_value
)
if not valid.all():
invalid_rows_str = _get_invalid_rows(valid)
raise ConstraintNotMetError(
'The fixed increments requirement has not been met because the data is not '
f"evenly divisible by '{self.increment_value}' for row indices: "
f'[{invalid_rows_str}]'
)
table_name = self._get_single_table_name(metadata)
invalid_rows = data[table_name].loc[~valid, [self.column_name]]
self._format_error_message_constraint(invalid_rows, table_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we leave the `raise ConstraintNotMetError` here and have `_format_error_message_constraint` just return the formatted string? I think it's clearer to raise the error here than to have it hidden in the helper method. Alternatively, we could include `raise` in the method name to make it clear that this helper raises the error.


def _get_updated_metadata(self, metadata):
"""Get the updated metadata after applying the constraint to the metadata.
Expand Down
7 changes: 2 additions & 5 deletions sdv/cag/inequality.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from sdv._utils import _convert_to_timedelta, _create_unique_name
from sdv.cag._errors import ConstraintNotMetError
from sdv.cag._utils import (
_get_invalid_rows,
_get_is_valid_dict,
_is_list_of_type,
_remove_columns_from_metadata,
Expand Down Expand Up @@ -145,10 +144,8 @@ def _validate_constraint_with_data(self, data, metadata):

valid = pd.isna(low) | pd.isna(high) | self._operator(high, low)
if not valid.all():
invalid_rows = _get_invalid_rows(valid)
raise ConstraintNotMetError(
f'The inequality requirement is not met for row indices: [{invalid_rows}]'
)
invalid_rows = data.loc[~valid, [self._low_column_name, self._high_column_name]]
self._format_error_message_constraint(invalid_rows, table_name)

def _get_diff_and_nan_column_names(self, metadata, column_name, table_name):
"""Get the column names for the difference and NaN columns.
Expand Down
8 changes: 2 additions & 6 deletions sdv/cag/one_hot_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import numpy as np

from sdv.cag._errors import ConstraintNotMetError
from sdv.cag._utils import (
_get_invalid_rows,
_get_is_valid_dict,
_is_list_of_type,
_validate_table_and_column_names,
Expand Down Expand Up @@ -74,10 +72,8 @@ def _validate_constraint_with_data(self, data, metadata):
table_name = self._get_single_table_name(metadata)
valid = self._get_valid_table_data(data[table_name])
if not valid.all():
invalid_rows_str = _get_invalid_rows(valid)
raise ConstraintNotMetError(
f'The one hot encoding requirement is not met for row indices: [{invalid_rows_str}]'
)
invalid_rows = data[table_name].loc[~valid, self._column_names]
self._format_error_message_constraint(invalid_rows, table_name)

def _fit(self, data, metadata):
"""Fit the constraint.
Expand Down
9 changes: 4 additions & 5 deletions sdv/cag/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from sdv._utils import _convert_to_timedelta, _create_unique_name
from sdv.cag._errors import ConstraintNotMetError
from sdv.cag._utils import (
_get_invalid_rows,
_get_is_valid_dict,
_is_list_of_type,
_remove_columns_from_metadata,
Expand Down Expand Up @@ -169,10 +168,10 @@ def _validate_constraint_with_data(self, data, metadata):
valid = self._get_valid_table_data(data[table_name])

if not valid.all():
invalid_rows_str = _get_invalid_rows(valid)
raise ConstraintNotMetError(
f'The range requirement is not met for row indices: [{invalid_rows_str}]'
)
invalid_rows = data[table_name].loc[
~valid, [self._low_column_name, self._middle_column_name, self._high_column_name]
]
self._format_error_message_constraint(invalid_rows, table_name)

def _get_diff_and_nan_column_names(self, metadata, table_name):
"""Create unique names for the low, high, and nan component columns."""
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/cag/test_inequality.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,10 @@ def test_invalid_data():
clean_data = data[~(data[['checkin_date', 'checkout_date']].isna().any(axis=1))]
data_invalid = clean_data.copy()
data_invalid.loc[0, 'checkin_date'] = '31 Dec 2020'
expected_error_msg = re.escape('The inequality requirement is not met for row indices: [0]')
expected_error_msg = re.escape(
"Data is not valid for the 'Inequality' constraint in table 'fake_hotel_guests':\n"
' checkin_date checkout_date\n0 31 Dec 2020 29 Dec 2020'
)

# Run and Assert
synthesizer = run_copula(clean_data, metadata, [constraint])
Expand Down
10 changes: 8 additions & 2 deletions tests/integration/cag/test_one_hot_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ def test_end_to_end_raises(data, metadata):
})

# Run and Assert
msg = re.escape('The one hot encoding requirement is not met for row indices: [1, 2]')
msg = re.escape(
"Data is not valid for the 'OneHotEncoding' constraint in table 'table':\n"
' a b c\n1 2 1.0 0\n2 0 NaN 3'
)
with pytest.raises(ConstraintNotMetError, match=msg):
run_copula(invalid_data, metadata, [OneHotEncoding(column_names=['a', 'b', 'c'])])

Expand Down Expand Up @@ -124,7 +127,10 @@ def test_end_to_end_multi_raises(data_multi, metadata_multi):
}

# Run and Assert
msg = re.escape('The one hot encoding requirement is not met for row indices: [1, 2]')
msg = re.escape(
"Data is not valid for the 'OneHotEncoding' constraint in table 'table1':\n "
'a b c\n1 2 1.0 0\n2 0 NaN 3'
)
with pytest.raises(ConstraintNotMetError, match=msg):
run_hma(invalid_data, metadata_multi, [constraint])

Expand Down
19 changes: 19 additions & 0 deletions tests/integration/cag/test_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,25 @@ def test_validate_constraints_raises(data, metadata, constraint):
synthesizer.validate_constraints(synthetic_data=synthetic_data)


def test_invalid_data(data, metadata, constraint):
    """Test that ``run_copula`` raises ``ConstraintNotMetError`` for data violating Range."""
    # Setup
    # Columns 'A' and 'B' are swapped relative to the fixture data.
    swapped_columns = {'A': data['B'], 'B': data['A'], 'C': data['C']}
    invalid_data = pd.DataFrame(swapped_columns)
    header = "Data is not valid for the 'Range' constraint in table 'table':\n"
    rows = (
        ' A B C\n0 10 1 100\n1 20 2 200\n2 30 3 300\n3 10'
        ' 1 100\n4 20 2 200\n+1 more'
    )
    expected_message = re.escape(header + rows)

    # Run and Assert
    with pytest.raises(ConstraintNotMetError, match=expected_message):
        run_copula(invalid_data, metadata, [constraint])


def test_validate_constraints_multi(data_multi, metadata_multi, constraint_multi):
"""Test validate_constraints with data generated with Range with multitable numerical data."""
synthesizer = run_hma(data_multi, metadata_multi, [constraint_multi])
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2662,7 +2662,10 @@ def test_end_to_end_with_cags():
data['guests'] = clean_data
invalid_data = data.copy()
invalid_data['guests'] = data_invalid
expected_error_msg = re.escape('The inequality requirement is not met for row indices: [0]')
expected_error_msg = re.escape(
"Data is not valid for the 'Inequality' constraint in table 'guests':\n amenities_lower"
' amenities_fee\n0 38.89 37.89'
)

# Run
synthesizer.fit(data)
Expand Down
6 changes: 4 additions & 2 deletions tests/integration/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,10 @@ def test_validate_with_failing_constraint():
low_column_name='checkin_date', high_column_name='checkout_date'
)
gc.add_constraints([checkin_lessthan_checkout])

error_msg = re.escape('The inequality requirement is not met for row indices: [0]')
error_msg = re.escape(
"Data is not valid for the 'Inequality' constraint in table 'fake_hotel_guests':\n"
' checkin_date checkout_date\n0 02 Jan 2021 29 Dec 2020'
)

# Run / Assert
with pytest.raises(ConstraintNotMetError, match=error_msg):
Expand Down
25 changes: 24 additions & 1 deletion tests/unit/cag/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import logging
import re
from copy import deepcopy
from unittest.mock import Mock
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
import pytest

from sdv.cag._errors import ConstraintNotMetError
from sdv.cag.base import BaseConstraint
from sdv.data_processing.datetime_formatter import DatetimeFormatter
from sdv.data_processing.numerical_formatter import NumericalFormatter
Expand Down Expand Up @@ -192,6 +193,28 @@ def test_get_updated_metadata(self):
instance.validate.assert_called_once_with(metadata=metadata)
instance._get_updated_metadata.assert_called_once()

@patch('sdv.cag.base._format_invalid_values_string')
def test__format_error_message_constraint(self, mock_format_invalid_values_string):
    """Test that `_format_error_message_constraint` raises the formatted error.

    The formatter is mocked, so the test checks that its output is embedded
    in the ``ConstraintNotMetError`` message and that it is called with the
    invalid data and ``5``.
    """
    # Setup
    invalid_data = {'row_1': 'value_1', 'row_2': 'value_2'}
    constraint = BaseConstraint()
    table_name = 'test_table'
    # The mock returns plain text; only the expected pattern is regex-escaped,
    # because `pytest.raises(match=...)` treats it as a regular expression.
    formatted_rows = 'checkin_date checkout_date\n0 31 Dec 2020 29 Dec 2020'
    mock_format_invalid_values_string.return_value = formatted_rows
    expected_error_message = re.escape(
        "Data is not valid for the 'BaseConstraint' constraint in table "
        f"'test_table':\n{formatted_rows}"
    )

    # Run and Assert
    with pytest.raises(ConstraintNotMetError, match=expected_error_message):
        constraint._format_error_message_constraint(invalid_data, table_name)

    mock_format_invalid_values_string.assert_called_once_with(invalid_data, 5)

def test__fit_constraint_column_formatters(self):
"""Test the `_fit_constraint_column_formatters` fits formatters for dropped columns."""
# Setup
Expand Down
18 changes: 10 additions & 8 deletions tests/unit/cag/test_fixed_increments.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,16 @@ def test__validate_constraint_with_data(self, values):
instance = FixedIncrements(
column_name=column_name, increment_value=increment_value, table_name=table_name
)
indices = data[table_name].index.tolist()
if len(indices) > 5:
indices = '[0, 1, 2, 3, 4, +2 more]'
err_msg = re.escape(
'The fixed increments requirement has not been met because the data is not '
f"evenly divisible by '{increment_value}' "
f'for row indices: {indices}'
)
if len(data[table_name]) > 5:
err_msg = re.escape(
"Data is not valid for the 'FixedIncrements' constraint in table 'table1':\n "
'odd\n0 1\n1 3\n2 5\n3 7\n4 9\n+2 more'
)
else:
err_msg = re.escape(
"Data is not valid for the 'FixedIncrements' constraint in table 'table1':\n "
'odd\n0 1\n1 3\n2 5\n3 7'
)

# Run and Assert
with pytest.raises(ConstraintNotMetError, match=err_msg):
Expand Down
34 changes: 27 additions & 7 deletions tests/unit/cag/test_inequality.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def test__validate_constraint_with_data(self):
})

# Run and Assert
err_msg = re.escape('The inequality requirement is not met for row indices: [1]')
err_msg = re.escape(
"Data is not valid for the 'Inequality' constraint in table 'table':\n low "
'high\n1 2 0'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)

Expand Down Expand Up @@ -195,7 +198,8 @@ def test__validate_constraint_with_data_multiple_rows(self):

# Run and Assert
err_msg = re.escape(
'The inequality requirement is not met for row indices: [1, 4, 6, 7, 8, +1 more]'
"Data is not valid for the 'Inequality' constraint in table 'table':\n low "
'high\n1 2 0\n4 5 0\n6 7 4\n7 8 0\n8 9 6\n+1 more'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)
Expand Down Expand Up @@ -227,7 +231,10 @@ def test__validate_constraint_with_data_nans(self):
})

# Run and Assert
err_msg = re.escape('The inequality requirement is not met for row indices: [2, 5]')
err_msg = re.escape(
"Data is not valid for the 'Inequality' constraint in table 'table':\n"
' low high\n2 3.0 2.0\n5 6.0 -6.0'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)

Expand Down Expand Up @@ -262,7 +269,10 @@ def test__validate_constraint_with_data_strict_boundaries_true(self):
})

# Run and Assert
err_msg = re.escape('The inequality requirement is not met for row indices: [2, 3, 5]')
err_msg = re.escape(
"Data is not valid for the 'Inequality' constraint in table 'table':\n low"
' high\n2 3.0 2.0\n3 4.0 4.0\n5 6.0 -6.0'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)

Expand Down Expand Up @@ -297,7 +307,10 @@ def test__validate_constraint_with_data_datetime(self):
})

# Run and Assert
err_msg = re.escape('The inequality requirement is not met for row indices: [1]')
err_msg = re.escape(
"Data is not valid for the 'Inequality' constraint in table 'table':\n "
' low high\n1 2021-09-01 2020-09-02'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)

Expand Down Expand Up @@ -332,7 +345,10 @@ def test__validate_constraint_with_data_datetime_objects(self):
})

# Run and Assert
err_msg = re.escape('The inequality requirement is not met for row indices: [1]')
err_msg = re.escape(
"Data is not valid for the 'Inequality' constraint in table 'table':\n"
' low high\n1 2021-9-1 2020-9-2'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)

Expand Down Expand Up @@ -397,7 +413,11 @@ def test__validate_constraint_with_data_datetime_objects_mismatching_formats(
})

# Run and Assert
err_msg = re.escape('The inequality requirement is not met for row indices: [0, 2]')
err_msg = re.escape(
"Data is not valid for the 'Inequality' constraint in table 'table':\n "
' low high\n0 2016-07-10 17:04:00 2016-07-10\n2 '
'2016-07-12 08:45:30 2016-07-12'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)

Expand Down
20 changes: 16 additions & 4 deletions tests/unit/cag/test_one_hot_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,37 @@ def test__validate_constraint_with_data(self):

# Row of all zeros
data = {'table': pd.DataFrame({'a': [1, 0, 0], 'b': [0, 0, 0], 'c': [0, 0, 1]})}
err_msg = re.escape('The one hot encoding requirement is not met for row indices: [1]')
err_msg = re.escape(
"Data is not valid for the 'OneHotEncoding' constraint in table 'table':\n "
' a b c\n1 0 0 0'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)

# Row with two 1s
data = {'table': pd.DataFrame({'a': [1, 0, 0], 'b': [0, 1, 1], 'c': [1, 0, 0]})}
err_msg = re.escape('The one hot encoding requirement is not met for row indices: [0]')
err_msg = re.escape(
"Data is not valid for the 'OneHotEncoding' constraint in table 'table':\n"
' a b c\n0 1 0 1'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)

# Invalid number
data = {'table': pd.DataFrame({'a': [1, 0, 0], 'b': [0, 2, 0], 'c': [0, 0, 1]})}
err_msg = re.escape('The one hot encoding requirement is not met for row indices: [1]')
err_msg = re.escape(
"Data is not valid for the 'OneHotEncoding' constraint in table 'table':\n"
' a b c\n1 0 2 0'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)

# Nans
data = {'table': pd.DataFrame({'a': [1, 0, 0], 'b': [0, 1, np.nan], 'c': [0, None, 1]})}
err_msg = re.escape('The one hot encoding requirement is not met for row indices: [1, 2]')
err_msg = re.escape(
"Data is not valid for the 'OneHotEncoding' constraint in table 'table':\n"
' a b c\n1 0 1.0 NaN\n2 0 NaN 1.0'
)
with pytest.raises(ConstraintNotMetError, match=err_msg):
instance._validate_constraint_with_data(data, metadata)

Expand Down
Loading
Loading