diff --git a/sdv/evaluation/single_table.py b/sdv/evaluation/single_table.py
index 950a72592..6adecfd55 100644
--- a/sdv/evaluation/single_table.py
+++ b/sdv/evaluation/single_table.py
@@ -167,3 +167,97 @@ def get_column_pair_plot(
     )
 
     return visualization.get_column_pair_plot(real_data, synthetic_data, column_names, plot_type)
+
+import numpy as np
+import plotly.graph_objects as go
+from sklearn.model_selection import ShuffleSplit
+
+
+def plot_learning_curve(real_data, synthetic_generator, metadata, train_sizes=None, n_splits=5,
+                        random_state=None):
+    """Plot a learning curve showing how synthetic data quality varies with training data size.
+
+    Args:
+        real_data (pandas.DataFrame):
+            The complete real dataset to use for generating the learning curve.
+        synthetic_generator (BaseGenerator):
+            An instance of a synthetic data generator with a ``sample()`` method that takes
+            a ``num_rows`` parameter. The generator should already be fitted on the
+            appropriate data.
+        metadata (Metadata):
+            The metadata object describing the real/synthetic data structure.
+        train_sizes (array-like or None):
+            List of floats between 0.0 and 1.0 representing the training set sizes to
+            evaluate. If None, defaults to ``np.linspace(0.1, 1.0, 5)``. Defaults to None.
+        n_splits (int):
+            Number of times to repeat the evaluation for each training size in order to
+            compute confidence intervals. Defaults to 5.
+        random_state (int or None):
+            Random seed for reproducibility. Defaults to None.
+
+    Returns:
+        plotly.graph_objects._figure.Figure:
+            Interactive plot showing the learning curve with a confidence interval.
+    """
+    # Set the default training sizes if none are provided
+    if train_sizes is None:
+        train_sizes = np.linspace(0.1, 1.0, 5)
+
+    # Store one quality score per (split, training size) pair
+    n_sizes = len(train_sizes)
+    scores = np.zeros((n_splits, n_sizes))
+
+    # Reuse the same cross-validation splits for every training size
+    cv = ShuffleSplit(n_splits=n_splits, random_state=random_state)
+
+    for size_idx, train_size in enumerate(train_sizes):
+        # Convert the fraction into an actual number of samples
+        n_samples = int(train_size * len(real_data))
+
+        for split_idx, (train_idx, _) in enumerate(cv.split(real_data)):
+            # Subsample the training data. ``train_idx`` holds at most 90% of the rows
+            # (ShuffleSplit's default train partition), so cap ``n_samples`` accordingly.
+            train_data = real_data.iloc[train_idx[:min(n_samples, len(train_idx))]]
+
+            # Generate synthetic data with the same number of rows as the training data
+            synthetic_data = synthetic_generator.sample(num_rows=len(train_data))
+
+            # Evaluate the synthetic data against this training subset
+            quality_report = evaluate_quality(train_data, synthetic_data, metadata, verbose=False)
+            scores[split_idx, size_idx] = quality_report.get_score()
+
+    # Aggregate the scores across splits
+    mean_scores = np.mean(scores, axis=0)
+    std_scores = np.std(scores, axis=0)
+
+    fig = go.Figure()
+
+    # Add the mean quality score line
+    fig.add_trace(go.Scatter(
+        x=train_sizes,
+        y=mean_scores,
+        name='Quality Score',
+        line=dict(color='blue'),
+        mode='lines+markers'
+    ))
+
+    # Add the ±1 standard deviation band as a filled polygon
+    fig.add_trace(go.Scatter(
+        x=np.concatenate([train_sizes, train_sizes[::-1]]),
+        y=np.concatenate([mean_scores + std_scores, (mean_scores - std_scores)[::-1]]),
+        fill='toself',
+        fillcolor='rgba(0, 0, 255, 0.2)',
+        line=dict(color='rgba(0,0,0,0)'),
+        showlegend=False,
+        name='Quality Score (±1 std)'
+    ))
+
+    fig.update_layout(
+        title='Synthetic Data Quality Learning Curve',
+        xaxis_title='Training Set Size (fraction of full dataset)',
+        yaxis_title='Quality Score',
+        hovermode='x unified',
+        template='plotly_white'
+    )
+
+    return fig
diff --git a/tests/integration/evaluation/test_single_table.py b/tests/integration/evaluation/test_single_table.py
index d943b02d3..9be1954df 100644
--- a/tests/integration/evaluation/test_single_table.py
+++ b/tests/integration/evaluation/test_single_table.py
@@ -1,10 +1,11 @@
 import pandas as pd
 
 from sdv.datasets.demo import download_demo
-from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, run_diagnostic
+from sdv.evaluation.single_table import (
+    evaluate_quality, get_column_pair_plot, plot_learning_curve, run_diagnostic)
 from sdv.metadata.metadata import Metadata
 from sdv.single_table.copulas import GaussianCopulaSynthesizer
 
 
 def test_evaluation():
     """Test ``evaluate_quality`` and ``run_diagnostic``."""
@@ -78,3 +78,20 @@ def test_column_pair_plot_sample_size_parameter():
     assert len(synthetic_data) == 500
     assert len(fig.data[0].x) == 40
     assert len(fig.data[1].x) == 40
+
+
+def test_plot_learning_curve():
+    """Test the ``plot_learning_curve`` function."""
+    # Setup
+    real_data, metadata = download_demo(modality='single_table', dataset_name='asia')
+    synthesizer = GaussianCopulaSynthesizer(metadata)
+    synthesizer.fit(real_data)
+
+    # Run
+    learning_curve = plot_learning_curve(
+        real_data, synthesizer, metadata,
+        train_sizes=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
+
+    # Assert
+    assert len(learning_curve.data) == 2
+    assert learning_curve.data[0].name == 'Quality Score'
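
Usage sketch (not part of the patch): a minimal way to drive the new helper end to end, assuming the `fake_hotel_guests` demo dataset already used elsewhere in this test suite; the output path `learning_curve.html` is purely illustrative.

    from sdv.datasets.demo import download_demo
    from sdv.evaluation.single_table import plot_learning_curve
    from sdv.single_table.copulas import GaussianCopulaSynthesizer

    # Fit a synthesizer once on the full demo data, then inspect how the
    # quality score changes as the real-data subset it is compared against grows.
    real_data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.fit(real_data)

    fig = plot_learning_curve(real_data, synthesizer, metadata, n_splits=3, random_state=42)
    fig.write_html('learning_curve.html')  # or fig.show() in a notebook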