Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
cb3c502
Add FixedCombinations CAG pattern + add CAG base class to public (#2400)
frances-h Feb 28, 2025
0031b8f
Add FixedIncrements CAG (#2409)
gsheni Mar 10, 2025
0c89d28
Add Inequality CAG (#2405)
fealho Mar 12, 2025
bec4427
Add Range CAG (#2407)
fealho Mar 12, 2025
5fb5121
Add OneHotEncoding CAG (#2414)
fealho Mar 12, 2025
78f5540
Store metadata as `Metadata` for `BaseSynthesizer` (#2422)
fealho Mar 19, 2025
1aedc37
Add CAG support to single table synthesizers (#2419)
fealho Mar 26, 2025
c7e7dae
Make single table CAGs backwards compatible (#2439)
fealho Mar 31, 2025
d646f85
Add deprecation warning to ScalarInequality and ScalarRange (#2442)
lajohn4747 Apr 1, 2025
b80871a
fix _validate_regex_format after rebasing
R-Palazzo Apr 10, 2025
516df0d
Draft PR to move cag base multi-table logic to SDV public (#2481)
R-Palazzo Apr 28, 2025
dea1a39
Add version parameter to SingleTableSynthesizer (#2503)
pvk-developer May 6, 2025
489590b
Add method to validate CAG pattern for synthetic data (#2485)
gsheni May 7, 2025
e6122c7
Inequality CAG errors out if data contains NaN values (#2507)
R-Palazzo May 8, 2025
7347cc6
Inequality CAG does not respect datetime format (#2506)
frances-h May 9, 2025
e7fed1c
Constraint hits IntCastingNanError when reverse transforming int colu…
frances-h May 12, 2025
054a4e8
Add `ProgrammableConstraint` and `ProgrammableSingleTableConstraint` …
frances-h May 13, 2025
0a5f98e
Add CAG validation to synthesizer.validate (#2480)
R-Palazzo May 14, 2025
851b72a
`auto_assign_transformers` errors after adding CAG pattern (#2511)
R-Palazzo May 14, 2025
4b3aa09
`ValueError` if conditionally sampling on a column dropped by constra…
frances-h May 14, 2025
ec688f5
Incorrect formatting when applying `Inequality` constraint (#2528)
frances-h May 15, 2025
5c2786f
Allow fit to be an optional method for programmable constraint (#2529)
pvk-developer May 15, 2025
25e0bea
Enable single-table constraint reject sampling with multi-table synth…
R-Palazzo May 16, 2025
8437d50
Add FixedCombinations CAG pattern + add CAG base class to public (#2400)
frances-h Feb 28, 2025
2c55d49
Add FixedIncrements CAG (#2409)
gsheni Mar 10, 2025
70a5f78
Add Inequality CAG (#2405)
fealho Mar 12, 2025
9a9ab83
Add Range CAG (#2407)
fealho Mar 12, 2025
717df35
Add OneHotEncoding CAG (#2414)
fealho Mar 12, 2025
c123985
Store metadata as `Metadata` for `BaseSynthesizer` (#2422)
fealho Mar 19, 2025
8c9eaa2
Add CAG support to single table synthesizers (#2419)
fealho Mar 26, 2025
35c16cf
Make single table CAGs backwards compatible (#2439)
fealho Mar 31, 2025
2fee1d1
Add deprecation warning to ScalarInequality and ScalarRange (#2442)
lajohn4747 Apr 1, 2025
b3d3362
fix _validate_regex_format after rebasing
R-Palazzo Apr 10, 2025
abfedf9
Draft PR to move cag base multi-table logic to SDV public (#2481)
R-Palazzo Apr 28, 2025
3de86d4
Add version parameter to SingleTableSynthesizer (#2503)
pvk-developer May 6, 2025
52f38c6
Add method to validate CAG pattern for synthetic data (#2485)
gsheni May 7, 2025
8d17594
Inequality CAG errors out if data contains NaN values (#2507)
R-Palazzo May 8, 2025
01bce6c
Inequality CAG does not respect datetime format (#2506)
frances-h May 9, 2025
ee68fc8
Constraint hits IntCastingNanError when reverse transforming int colu…
frances-h May 12, 2025
8a90f47
Add `ProgrammableConstraint` and `ProgrammableSingleTableConstraint` …
frances-h May 13, 2025
002c39e
Add CAG validation to synthesizer.validate (#2480)
R-Palazzo May 14, 2025
1a9050d
`auto_assign_transformers` errors after adding CAG pattern (#2511)
R-Palazzo May 14, 2025
7dade32
`ValueError` if conditionally sampling on a column dropped by constra…
frances-h May 14, 2025
647bbf2
Incorrect formatting when applying `Inequality` constraint (#2528)
frances-h May 15, 2025
6f495a9
Allow fit to be an optional method for programmable constraint (#2529)
pvk-developer May 15, 2025
f3e05bd
Enable single-table constraint reject sampling with multi-table synth…
R-Palazzo May 16, 2025
554bbda
Merge branch 'feature/single-table-CAG' of https://github.com/sdv-dev…
fealho May 20, 2025
a851467
Evaluate and improve CAG pattern testing coverage (#2518)
fealho May 21, 2025
cc71660
Merge branch 'feature/single-table-CAG' of https://github.com/sdv-dev…
fealho May 21, 2025
4e92984
Improve testing suite
fealho May 13, 2025
3b88c65
Fix rebase
fealho May 19, 2025
c128250
Add failing tests
fealho May 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions sdv/cag/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""SDV CAG module."""

# Imports kept alphabetized (isort convention); `__all__` mirrors the order.
from sdv.cag.fixed_combinations import FixedCombinations
from sdv.cag.fixed_increments import FixedIncrements
from sdv.cag.inequality import Inequality
from sdv.cag.one_hot_encoding import OneHotEncoding
from sdv.cag.programmable_constraint import (
    ProgrammableConstraint,
    SingleTableProgrammableConstraint,
)
from sdv.cag.range import Range

__all__ = (
    'FixedCombinations',
    'FixedIncrements',
    'Inequality',
    'OneHotEncoding',
    'ProgrammableConstraint',
    'Range',
    'SingleTableProgrammableConstraint',
)
5 changes: 5 additions & 0 deletions sdv/cag/_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""CAG Pattern Exceptions."""


class PatternNotMetError(Exception):
    """Raised when data or metadata fails to satisfy a CAG pattern."""
149 changes: 149 additions & 0 deletions sdv/cag/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import re

import numpy as np
import pandas as pd

from sdv.cag._errors import PatternNotMetError
from sdv.metadata import Metadata


def _validate_columns_in_metadata(table_name, columns, metadata):
"""Validates that the columns are in the metadata.

Args:
table_name (str):
The name of the table in the metadata.

columns (list[str])
The column names to check.

metadata (sdv.metadata.Metadata):
The Metadata to check.
"""
if not set(columns).issubset(set(metadata.tables[table_name].columns)):
missing_columns = set(columns) - set(metadata.tables[table_name].columns)
missing_columns = "', '".join(sorted(missing_columns))
raise PatternNotMetError(f"Table '{table_name}' is missing columns '{missing_columns}'.")


def _validate_table_and_column_names(table_name, columns, metadata):
    """Validate the table name and columns against the metadata.

    It checks the following:
    - If the table name is None, the metadata should only contain a single table.
    - The table name is in the metadata.
    - The columns are in the metadata.

    Args:
        table_name (str):
            The name of the table in the metadata to validate.

        columns (list[str]):
            The column names to check.

        metadata (sdv.metadata.Metadata):
            The Metadata to check.

    Raises:
        PatternNotMetError:
            If the table cannot be resolved or any column is missing.
    """
    if table_name is None:
        if len(metadata.tables) > 1:
            raise PatternNotMetError(
                'Metadata contains more than 1 table but no ``table_name`` provided.'
            )

        table_name = metadata._get_single_table_name()
    elif table_name not in metadata.tables:
        raise PatternNotMetError(f"Table '{table_name}' missing from metadata.")

    _validate_columns_in_metadata(table_name, columns, metadata)


def _validate_table_name_if_defined(table_name):
"""Validate if the table name is defined, it is a string."""
if table_name and not isinstance(table_name, str):
raise ValueError('`table_name` must be a string or None.')


def _is_list_of_type(values, type_to_check=str):
"""Checks that 'values' is a list and all elements are of type 'type_to_check'."""
return isinstance(values, list) and all(isinstance(value, type_to_check) for value in values)


def _get_invalid_rows(valid):
"""Determine the indices of the rows where value is False.

Args:
valid (pd.Series):
The input data to check for False values.

Returns:
(str): A string that describes the indices where the value is False.
If there are more than 5 indices, the rest are described as 'more'.
"""
invalid_rows = np.where(~valid)[0]
if len(invalid_rows) <= 5:
invalid_rows_str = ', '.join(str(i) for i in invalid_rows)
else:
first_five = ', '.join(str(i) for i in invalid_rows[:5])
remaining = len(invalid_rows) - 5
invalid_rows_str = f'{first_five}, +{remaining} more'
return invalid_rows_str


def _get_is_valid_dict(data, table_name):
"""Create a dictionary of True values for each table besides table_name.

Besides table_name, all rows of every other table are considered valid,
so the boolean Series will be True for all rows of every other table.

Args:
data (dict):
The data.
table_name (str):
The name of the table to exclude from the dictionary.

Returns:
dict:
Dictionary of table names to boolean Series of True values.
"""
return {
table: pd.Series(True, index=table_data.index)
for table, table_data in data.items()
if table != table_name
}


def _convert_to_snake_case(string):
"""Convert a string to snake case (words separated by underscores, all lowercase)."""
return re.sub(r'([a-z])([A-Z])', r'\1_\2', string).lower()


def _remove_columns_from_metadata(metadata, table_name, columns_to_drop):
    """Remove columns from metadata, including column relationships.

    Will raise an error if the primary key is being dropped.

    Args:
        metadata (dict, sdv.metadata.Metadata): The Metadata which contains
            the columns to drop.
        table_name (str): Name of the table in the metadata, where the column(s)
            are located.
        columns_to_drop (list[str]): The list of column names to drop from the
            Metadata.

    Returns:
        (sdv.metadata.Metadata): The new Metadata, with the columns removed.

    Raises:
        ValueError: If the table's primary key is among ``columns_to_drop``.
    """
    if isinstance(metadata, Metadata):
        metadata = metadata.to_dict()

    table_dict = metadata['tables'][table_name]
    column_set = set(columns_to_drop)
    primary_key = table_dict.get('primary_key')

    # Fail fast, before any deletion: previously the check ran inside the loop,
    # so a dict passed by the caller could be left partially mutated when the
    # primary key happened to be among the columns to drop.
    if primary_key and primary_key in column_set:
        raise ValueError('Cannot remove primary key from Metadata')

    for column in column_set:
        del table_dict['columns'][column]

    # Drop any column relationship that references a removed column.
    table_dict['column_relationships'] = [
        rel
        for rel in table_dict.get('column_relationships', [])
        if set(rel['column_names']).isdisjoint(column_set)
    ]
    return Metadata.load_from_dict(metadata)
220 changes: 220 additions & 0 deletions sdv/cag/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""Base CAG constraint pattern."""

import logging

import numpy as np
import pandas as pd

from sdv.errors import NotFittedError

LOGGER = logging.getLogger(__name__)


class BasePattern:
    """Base CAG Pattern Class.

    Subclasses implement the private hooks (``_fit``, ``_transform``,
    ``_reverse_transform``, ``_is_valid`` and the two validation hooks) while
    this base class handles single-table/multi-table conversion, dtype
    restoration and fitted-state checks.
    """

    _is_single_table = True

    def __init__(self):
        self.metadata = None
        self._fitted = False
        # Set to True during ``fit`` when the pattern was fit on a single
        # DataFrame rather than a dictionary of tables.
        self._single_table = False

    def _convert_data_to_dictionary(self, data, metadata, copy=False):
        """Helper to handle converting single dataframes into dictionaries.

        This method takes in data, metadata, and, optionally, a flag indicating if the
        returned data should be a copy of the original input data. If the data is a single
        dataframe, it converts it into a dictionary of dataframes.
        """
        if isinstance(data, pd.DataFrame):
            if copy:
                data = data.copy()

            if self._single_table:
                data = {self._table_name: data}
            else:
                table_name = self._get_single_table_name(metadata)
                data = {table_name: data}
        elif copy:
            data = {table_name: table_data.copy() for table_name, table_data in data.items()}

        return data

    def _get_single_table_name(self, metadata):
        # Subclasses are expected to define a ``table_name`` attribute; fall
        # back to the metadata's only table when it is None.
        if not hasattr(self, 'table_name'):
            raise ValueError('No ``table_name`` attribute has been set.')

        return metadata._get_single_table_name() if self.table_name is None else self.table_name

    def _validate_pattern_with_metadata(self, metadata):
        """Hook: validate the pattern against metadata. Must be overridden."""
        raise NotImplementedError()

    def _validate_pattern_with_data(self, data, metadata):
        """Hook: validate the pattern against data. Must be overridden."""
        raise NotImplementedError()

    def validate(self, data=None, metadata=None):
        """Validate the data/metadata meets the pattern requirements.

        Args:
            data (dict[str, pd.DataFrame], optional)
                The data dictionary. If `None`, ``validate`` will skip data validation.
            metadata (sdv.Metadata, optional)
                The input metadata. If `None`, pattern must have been fitted and ``validate``
                will use the metadata saved during fitting.
        """
        if metadata is None:
            if self.metadata is None:
                raise NotFittedError('Pattern must be fit before validating without metadata.')

            metadata = self.metadata

        self._validate_pattern_with_metadata(metadata)

        if data is not None:
            data = self._convert_data_to_dictionary(data, metadata)
            self._validate_pattern_with_data(data, metadata)

    def _get_updated_metadata(self, metadata):
        # Default: the pattern does not alter the metadata.
        return metadata

    def get_updated_metadata(self, metadata):
        """Get the updated metadata after applying the pattern to the input metadata.

        Args:
            metadata (sdv.Metadata):
                The input metadata to apply the pattern to.
        """
        self.validate(metadata=metadata)
        return self._get_updated_metadata(metadata)

    def _fit(self, data, metadata):
        """Hook: learn pattern state from data. Must be overridden."""
        raise NotImplementedError

    def fit(self, data, metadata):
        """Fit the pattern with data and metadata.

        Args:
            data (dict[pd.DataFrame]):
                The data dictionary to fit the pattern on.
            metadata (sdv.Metadata):
                The metadata to fit the pattern on.
        """
        self._validate_pattern_with_metadata(metadata)
        if isinstance(data, pd.DataFrame):
            # Remember that the caller works with a single table so transform,
            # reverse_transform and is_valid can return a bare DataFrame/Series.
            self._single_table = True
            self._table_name = self._get_single_table_name(metadata)
            data = self._convert_data_to_dictionary(data, metadata)

        self._validate_pattern_with_data(data, metadata)
        self._fit(data, metadata)
        self.metadata = metadata

        # Snapshot dtypes and column order so reverse_transform can restore them.
        self._dtypes = {table: data[table].dtypes.to_dict() for table in metadata.tables}
        self._original_data_columns = {
            table: metadata.tables[table].get_column_names() for table in metadata.tables
        }
        self._fitted = True

    def _transform(self, data):
        """Hook: transform the data dictionary. Must be overridden."""
        raise NotImplementedError

    def transform(self, data):
        """Transform the data.

        Args:
            data (dict[str, pd.DataFrame])
                The input data dictionary to be transformed.
        """
        if not self._fitted:
            raise NotFittedError('Pattern must be fit using ``fit`` before transforming.')

        self.validate(data)
        data = self._convert_data_to_dictionary(data, self.metadata, copy=True)
        transformed_data = self._transform(data)
        if self._single_table:
            return transformed_data[self._table_name]

        return transformed_data

    def _reverse_transform(self, data):
        """Hook: reverse transform the data dictionary. Must be overridden."""
        raise NotImplementedError

    def _table_as_type_by_col(self, reverse_transformed, table, table_name):
        """Cast table to given types on a column by column basis.

        Args:
            reverse_transformed (dict[str, pd.DataFrame])
                The reverse transformed data dictionary
            table (pd.DataFrame)
                The reverse transformed table
            table_name (str)
                The name of the table
        """
        for col in table:
            try:
                reverse_transformed[table_name][col] = table[col].astype(
                    self._dtypes[table_name][col]
                )
            except pd.errors.IntCastingNaNError:
                # Integer columns cannot hold NaN; permanently relax the
                # recorded dtype to float so future casts succeed too.
                LOGGER.info(
                    "Column '%s' is being converted to float because it contains NaNs.", col
                )
                self._dtypes[table_name][col] = np.dtype('float64')
                reverse_transformed[table_name][col] = table[col].astype(
                    self._dtypes[table_name][col]
                )

    def reverse_transform(self, data):
        """Reverse transform the data back into the original space.

        Args:
            data (dict[str, pd.DataFrame])
                The transformed data dictionary to be reverse transformed.

        Raises:
            NotFittedError:
                If the pattern has not been fit yet.
        """
        # Guard added for consistency with ``transform`` and ``is_valid``;
        # without it an unfitted pattern fails with an opaque AttributeError.
        if not self._fitted:
            raise NotFittedError('Pattern must be fit using ``fit`` before reverse transforming.')

        data = self._convert_data_to_dictionary(data, self.metadata, copy=True)
        reverse_transformed = self._reverse_transform(data)
        for table_name, table in reverse_transformed.items():
            # Restore the original column order and dtypes, ignoring any
            # columns the reverse transform did not produce.
            valid_columns = [
                column
                for column in self._original_data_columns[table_name]
                if column in table.columns
            ]
            dtypes = {col: self._dtypes[table_name][col] for col in valid_columns}
            table = table[valid_columns]
            try:
                reverse_transformed[table_name] = table.astype(dtypes)
            except pd.errors.IntCastingNaNError:
                # iterate over the columns and cast individually
                self._table_as_type_by_col(reverse_transformed, table, table_name)

        if self._single_table:
            return reverse_transformed[self._table_name]

        return reverse_transformed

    def _is_valid(self, data):
        """Hook: compute per-row validity. Must be overridden."""
        raise NotImplementedError

    def is_valid(self, data):
        """Say whether the given table rows are valid.

        Args:
            data (pd.DataFrame or dict[pd.DataFrame]):
                Table data.

        Returns:
            pd.Series or dict[pd.Series]:
                Series of boolean values indicating if the row is valid for the pattern or not.
        """
        if not self._fitted:
            raise NotFittedError(
                'Pattern must be fit using ``fit`` before determining if data is valid.'
            )

        data = self._convert_data_to_dictionary(data, self.metadata)
        is_valid_data = self._is_valid(data)
        if self._single_table:
            return is_valid_data[self._table_name]

        return is_valid_data
Loading
Loading