diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py
index c3a5dcea..ac1ea7e3 100644
--- a/dice_ml/counterfactual_explanations.py
+++ b/dice_ml/counterfactual_explanations.py
@@ -2,9 +2,12 @@
 import os
 
 import jsonschema
+import numpy as np
+import pandas as pd
+from counterplots import CreatePlot
 from raiutils.exceptions import UserConfigValidationException
 
-from dice_ml.constants import _SchemaVersions
+from dice_ml.constants import BackEndTypes, _SchemaVersions
 from dice_ml.diverse_counterfactuals import (CounterfactualExamples,
                                              _DiverseCFV2SchemaConstants)
 
@@ -111,6 +114,78 @@ def visualize_as_list(self, display_sparse_df=True,
                 display_sparse_df=display_sparse_df,
                 show_only_changes=show_only_changes)
 
+    def plot_counterplots(self, dice_model):
+        """Plot counterfactuals with the CounterPlots package.
+
+        One plot object is created per counterfactual of every query instance
+        in ``self.cf_examples_list``. The last column of the factual and
+        counterfactual dataframes is dropped before plotting.
+
+        :param dice_model: DiCE's model object.
+        :return: List with one ``CreatePlot`` object per counterfactual.
+        """
+        counterplots_out = []
+        for cf_examples in self.cf_examples_list:
+            final_cfs_df = cf_examples.final_cfs_df
+            if final_cfs_df is None:
+                # No counterfactuals for this query instance: nothing to plot.
+                continue
+
+            test_instance_df = cf_examples.test_instance_df
+            feature_names = list(test_instance_df.columns)[:-1]
+            feature_dtypes = list(test_instance_df.dtypes)[:-1]
+            factual_instance = test_instance_df.to_numpy()[0][:-1]
+
+            model_pred = self._build_counterplots_predictor(
+                dice_model, feature_names, feature_dtypes, factual_instance)
+
+            for cf_instance in final_cfs_df.to_numpy():
+                counterplots_out.append(
+                    CreatePlot(
+                        factual=factual_instance,
+                        cf=cf_instance[:-1],
+                        model_pred=model_pred,
+                        feature_names=feature_names,
+                    ))
+        return counterplots_out
+
+    @staticmethod
+    def _build_counterplots_predictor(dice_model, feature_names, feature_dtypes,
+                                      factual_instance):
+        """Build the prediction function handed to CounterPlots.
+
+        The returned function only depends on the arguments of this call, so
+        it stays valid for its query instance after other query instances
+        have been processed.
+
+        :param dice_model: DiCE's model object.
+        :param feature_names: Names of the feature columns.
+        :param feature_dtypes: Dtypes of the feature columns, in the same order.
+        :param factual_instance: Feature values of the query instance.
+        :return: Function mapping array-like instances to model outputs.
+        """
+        def convert_data(x):
+            # Rebuild a dataframe with the original column names and dtypes.
+            df_x = pd.DataFrame(data=x, columns=feature_names)
+            for feature_name, f_dtype in zip(feature_names, feature_dtypes):
+                df_x[feature_name] = pd.to_numeric(df_x[feature_name], errors='ignore').astype(f_dtype)
+            return df_x
+
+        if dice_model.backend == BackEndTypes.Sklearn:
+            factual_proba = dice_model.model.predict_proba(convert_data([factual_instance]))
+            # Take the first row explicitly so that argmax yields a class index.
+            factual_class_idx = int(np.argmax(factual_proba[0]))
+
+            def model_pred(x):
+                # One-vs-rest score: probability of the factual class.
+                pred_prob = dice_model.model.predict_proba(convert_data(x))
+                return pred_prob[:, factual_class_idx]
+        else:
+            def model_pred(x):
+                return dice_model.model.predict(dice_model.transformer.transform(convert_data(x)))
+
+        return model_pred
+
     @staticmethod
     def _check_cf_exp_output_against_json_schema(
             cf_dict, version):
diff --git a/requirements.txt b/requirements.txt
index 7f89d1f4..1c2686cb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ pandas<2.0.0
 scikit-learn
 tqdm
 raiutils>=0.4.0
+counterplots>=0.0.7
diff --git a/tests/test_counterfactual_explanations.py b/tests/test_counterfactual_explanations.py
index 4dcb5628..7593701a 100644
--- a/tests/test_counterfactual_explanations.py
+++ b/tests/test_counterfactual_explanations.py
@@ -1,5 +1,9 @@
 import json
+import unittest
+from unittest.mock import Mock, patch
 
+import numpy as np
+import pandas as pd
 import pytest
 from raiutils.exceptions import UserConfigValidationException
 
@@ -319,3 +323,109 @@ def test_unsupported_versions_to_json(self, unsupported_version):
             counterfactual_explanations.to_json()
 
         assert "Unsupported serialization version {}".format(unsupported_version) in str(ucve)
+
+
+class TestCounterfactualExplanationsPlot(unittest.TestCase):
+
+    @patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot")
+    def test_plot_counterplots_sklearn(self, mock_create_plot):
+        # Dummy DiCE's model object with a Sklearn backend
+        dummy_model = Mock()
+        dummy_model.backend = "sklearn"
+        dummy_model.model.predict_proba = Mock(return_value=np.array([[0.4, 0.6], [0.2, 0.8]]))
+
+        # Sample cf_examples to test with
+        cf_examples_mock = Mock()
+        cf_examples_mock.test_instance_df = pd.DataFrame({
+            'feature1': [1],
+            'feature2': [2],
+            'target': [0]
+        })
+        cf_examples_mock.final_cfs_df = pd.DataFrame({
+            'feature1': [1.1, 1.2],
+            'feature2': [2.1, 2.2],
+            'target': [1, 1]
+        })
+
+        counterfact = CounterfactualExplanations(
+            cf_examples_list=[cf_examples_mock],
+            local_importance=None,
+            summary_importance=None,
+            version=None)
+
+        # Call function
+        result = counterfact.plot_counterplots(dummy_model)
+
+        # Assert the CreatePlot was called twice (as there are 2 counterfactual instances)
+        assert mock_create_plot.call_count == 2
+
+        # Assert that the result is as expected
+        assert result == ["dummy_plot", "dummy_plot"]
+
+        # Assert that the prediction function returns the factual class probability
+        model_pred = mock_create_plot.call_args[1]['model_pred']
+        np.testing.assert_allclose(model_pred([[1.1, 2.1], [1.2, 2.2]]), [0.6, 0.8])
+
+    @patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot")
+    def test_plot_counterplots_non_sklearn(self, mock_create_plot):
+        # Sample Non-Sklearn backend
+        dummy_model = Mock()
+        dummy_model.backend = "NonSklearn"
+        dummy_model.model.predict = Mock(return_value=np.array([0, 1]))
+        dummy_model.transformer = Mock()
+        dummy_model.transformer.transform = Mock(return_value=np.array([[1, 2], [1.1, 2.1]]))
+
+        # Sample cf_examples to test with
+        cf_examples_mock = Mock()
+        cf_examples_mock.test_instance_df = pd.DataFrame({
+            'feature1': [1],
+            'feature2': [2],
+            'target': [0]
+        })
+        cf_examples_mock.final_cfs_df = pd.DataFrame({
+            'feature1': [1.1, 1.2],
+            'feature2': [2.1, 2.2],
+            'target': [1, 1]
+        })
+
+        counterfact = CounterfactualExplanations(
+            cf_examples_list=[cf_examples_mock],
+            local_importance=None,
+            summary_importance=None,
+            version=None)
+
+        # Call function
+        result = counterfact.plot_counterplots(dummy_model)
+
+        # Assert the CreatePlot was called twice (as there are 2 counterfactual instances)
+        assert mock_create_plot.call_count == 2
+
+        # Assert that the result is as expected
+        assert result == ["dummy_plot", "dummy_plot"]
+
+    @patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot")
+    def test_plot_counterplots_without_counterfactuals(self, mock_create_plot):
+        # Sample query instance for which no counterfactuals are available
+        dummy_model = Mock()
+        dummy_model.backend = "sklearn"
+
+        cf_examples_mock = Mock()
+        cf_examples_mock.test_instance_df = pd.DataFrame({
+            'feature1': [1],
+            'feature2': [2],
+            'target': [0]
+        })
+        cf_examples_mock.final_cfs_df = None
+
+        counterfact = CounterfactualExplanations(
+            cf_examples_list=[cf_examples_mock],
+            local_importance=None,
+            summary_importance=None,
+            version=None)
+
+        # Call function
+        result = counterfact.plot_counterplots(dummy_model)
+
+        # Assert that nothing was plotted
+        mock_create_plot.assert_not_called()
+        assert result == []