Skip to content

Learning Curve Plot for Evaluation #2354

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions sdv/evaluation/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,103 @@ def get_column_pair_plot(
)

return visualization.get_column_pair_plot(real_data, synthetic_data, column_names, plot_type)

from sdv.datasets.demo import download_demo, get_available_demos

def plot_learning_curve(real_data, synthetic_generator, metadata, train_sizes=None, n_splits=5,
                        random_state=None):
    """Plot learning curve showing how synthetic data quality varies with training data size.

    For each requested training size, the real data is subsampled ``n_splits`` times,
    synthetic data of matching size is drawn from ``synthetic_generator``, and the
    SDMetrics quality score between the subsample and the synthetic sample is recorded.
    The mean score per size is plotted with a ±1 standard deviation band.

    Args:
        real_data (pandas.DataFrame):
            The complete real dataset to use for generating learning curves.
        synthetic_generator (BaseGenerator):
            An instance of a synthetic data generator with a sample() method that takes
            a num_rows parameter. The generator should already be fitted on the appropriate data.
        metadata (Metadata):
            The metadata object describing the real/synthetic data structure.
        train_sizes (array-like or None):
            List of floats between 0.0 and 1.0 representing training set sizes to evaluate.
            If None, defaults to np.linspace(0.1, 1.0, 5). Defaults to None.
        n_splits (int):
            Number of times to repeat evaluation for each training size to compute confidence
            intervals. Defaults to 5.
        random_state (int or None):
            Random seed for reproducibility. Defaults to None.

    Returns:
        plotly.graph_objects._figure.Figure:
            Interactive plot showing learning curves with confidence intervals.
    """
    # Local imports keep this optional plotting utility from adding hard
    # module-level dependencies. (These were previously commented out, which
    # made the function raise NameError on first use.)
    import numpy as np
    import plotly.graph_objects as go
    from sklearn.model_selection import ShuffleSplit

    # Set default training sizes if none provided.
    if train_sizes is None:
        train_sizes = np.linspace(0.1, 1.0, 5)
    train_sizes = np.asarray(train_sizes, dtype=float)

    # scores[i, j] holds the quality score for CV split i at training size j.
    n_sizes = len(train_sizes)
    scores = np.zeros((n_splits, n_sizes))

    # One shuffler reused across sizes so each size sees the same n_splits shuffles.
    cv = ShuffleSplit(n_splits=n_splits, random_state=random_state)

    for size_idx, train_size in enumerate(train_sizes):
        # Actual number of rows for this training size (at least one row).
        n_samples = max(1, int(train_size * len(real_data)))

        for split_idx, (train_idx, _) in enumerate(cv.split(real_data)):
            # NOTE: ShuffleSplit's train fold only holds 90% of the rows by
            # default, so for train_size == 1.0 the slice is capped at the
            # fold length rather than the full dataset size.
            take = min(n_samples, len(train_idx))
            train_data = real_data.iloc[train_idx[:take]]

            # Generate synthetic data using the same number of rows as training data.
            synthetic_data = synthetic_generator.sample(num_rows=len(train_data))

            # Evaluate quality of the synthetic sample against this subsample.
            quality_report = evaluate_quality(train_data, synthetic_data, metadata, verbose=False)
            scores[split_idx, size_idx] = quality_report.get_score()

    # Aggregate across splits for the mean line and confidence band.
    mean_scores = scores.mean(axis=0)
    std_scores = scores.std(axis=0)

    fig = go.Figure()

    # Mean quality score per training size.
    fig.add_trace(go.Scatter(
        x=train_sizes,
        y=mean_scores,
        name='Quality Score',
        line=dict(color='blue'),
        mode='lines+markers'
    ))

    # ±1 std confidence band drawn as a closed polygon (out along the top,
    # back along the bottom in reverse order).
    fig.add_trace(go.Scatter(
        x=np.concatenate([train_sizes, train_sizes[::-1]]),
        y=np.concatenate([mean_scores + std_scores, (mean_scores - std_scores)[::-1]]),
        fill='toself',
        fillcolor='blue',
        opacity=0.2,
        line=dict(color='rgba(0,0,0,0)'),
        showlegend=False,
        name='Quality Score (±1 std)'
    ))

    fig.update_layout(
        title='Synthetic Data Quality Learning Curve',
        xaxis_title='Training Set Size (fraction of full dataset)',
        yaxis_title='Quality Score',
        hovermode='x unified',
        template='plotly_white'
    )

    return fig
16 changes: 13 additions & 3 deletions tests/integration/evaluation/test_single_table.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pandas as pd

from sdv.datasets.demo import download_demo
from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, run_diagnostic
from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, run_diagnostic, plot_learning_curve
from sdv.metadata.metadata import Metadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer

from sdv.single_table.ctgan import CTGANSynthesizer

def test_evaluation():
"""Test ``evaluate_quality`` and ``run_diagnostic``."""
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_evaluation_metadata():
def test_column_pair_plot_sample_size_parameter():
"""Test the sample_size parameter for the column pair plot."""
# Setup
real_data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')
real_data, metadata = download_demo(modality='single_table', dataset_name='expedia_hotel_logs')
synthesizer = GaussianCopulaSynthesizer(metadata)
synthesizer.fit(real_data)
synthetic_data = synthesizer.sample(len(real_data))
Expand All @@ -78,3 +78,13 @@ def test_column_pair_plot_sample_size_parameter():
assert len(synthetic_data) == 500
assert len(fig.data[0].x) == 40
assert len(fig.data[1].x) == 40

def test_plot_learning_curve():
    """Test the ``plot_learning_curve`` function."""
    # Setup
    real_data, metadata = download_demo(modality='single_table', dataset_name='asia')
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.fit(real_data)
    train_sizes = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    # Run
    learning_curve = plot_learning_curve(real_data, synthesizer, metadata, train_sizes=train_sizes)

    # Assert — check the figure instead of calling .show(), which would open
    # an interactive renderer during the test run and asserts nothing.
    assert learning_curve is not None
    # Trace 0 is the mean line (one point per training size); trace 1 is the
    # confidence band polygon (each size appears twice: top and bottom edge).
    assert len(learning_curve.data) == 2
    assert len(learning_curve.data[0].x) == len(train_sizes)
    assert len(learning_curve.data[1].x) == 2 * len(train_sizes)
    assert learning_curve.layout.title.text == 'Synthetic Data Quality Learning Curve'