Skip to content

Commit

Permalink
added pickle tests
Browse files Browse the repository at this point in the history
  • Loading branch information
oegedijk committed Sep 30, 2020
1 parent f48b54c commit 5d375e2
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 46 deletions.
11 changes: 8 additions & 3 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# Release Notes

## version 0.2.3.3
## version 0.2.4

### New Features
- added ExplainerDashboard parameter "responsive" (defaults to True) to make
the dashboard layout responsive on mobile devices. Set it to False when e.g.
running tests on headless browsers.

### Bug Fixes
- Fixes bug that made RandomForest and xgboost explainer unpicklable
- Fixes bug that made RandomForest and xgboost explainers unpicklable

### Improvements
- Now dashboard is mobile responsive by default
- Added tests for picklability of explainers


## Version 0.2.3
Expand Down
25 changes: 19 additions & 6 deletions explainerdashboard/dashboards.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def __init__(self, explainer=None, tabs=None,
external_stylesheets=None,
server=True,
url_base_pathname=None,
responsive=True,
importances=True,
model_summary=True,
contributions=True,
Expand Down Expand Up @@ -311,6 +312,14 @@ def __init__(self, explainer=None, tabs=None,
height(int, optional): height of notebookn output cell in pixels, defaults to 800.
external_stylesheets(list, optional): attach dbc themes e.g.
`external_stylesheets=[dbc.themes.FLATLY]`.
server (Flask instance or bool): either an instance of an existing Flask
server to tie the dashboard to, or True in which case a new Flask
server is created.
url_base_pathname (str): url_base_pathname for dashboard,
e.g. "/dashboard". Defaults to None.
responsive (bool): make layout responsive to viewport size
(i.e. reorganize bootstrap columns on small devices). Set to False
when e.g. testing with a headless browser. Defaults to True.
importances(bool, optional): include ImportancesTab, defaults to True.
model_summary(bool, optional): include ModelSummaryTab, defaults to True.
contributions(bool, optional): include ContributionsTab, defaults to True.
Expand All @@ -327,6 +336,7 @@ def __init__(self, explainer=None, tabs=None,
hide_header, header_hide_title, header_hide_selector
self.external_stylesheets = external_stylesheets
self.server, self.url_base_pathname = server, url_base_pathname
self.responsive = responsive

self.app = self._get_dash_app()
self.app.title = title
Expand Down Expand Up @@ -383,7 +393,7 @@ def __init__(self, explainer=None, tabs=None,

print("Calculating dependencies...", flush=True)
explainer_layout.calculate_dependencies()
print("registering callbacks...", flush=True)
print("Registering callbacks...", flush=True)
explainer_layout.register_callbacks(self.app)

def _convert_str_tabs(self, component):
Expand All @@ -403,21 +413,23 @@ def _convert_str_tabs(self, component):
return component

def _get_dash_app(self):
if self.responsive:
meta_tags = [{'name': 'viewport', 'content': 'width=device-width, initial-scale=1.0, maximum-scale=1.2, minimum-scale=0.5,'}]
else:
meta_tags = None
if self.mode=="dash":
if self.external_stylesheets is not None:
app = dash.Dash(server=self.server,
external_stylesheets=self.external_stylesheets,
assets_url_path="",
url_base_pathname=self.url_base_pathname,
meta_tags=[{'name': 'viewport',
'content': 'width=device-width, initial-scale=1.0, maximum-scale=1.2, minimum-scale=0.5,'}])
meta_tags=meta_tags)
app.config['suppress_callback_exceptions'] = True
else:
app = dash.Dash(__name__,
server=self.server,
url_base_pathname=self.url_base_pathname,
meta_tags=[{'name': 'viewport',
'content': 'width=device-width, initial-scale=1.0, maximum-scale=1.2, minimum-scale=0.5,'}])
meta_tags=meta_tags)
app.config['suppress_callback_exceptions'] = True
app.css.config.serve_locally = True
app.scripts.config.serve_locally = True
Expand All @@ -426,7 +438,8 @@ def _get_dash_app(self):
if self.external_stylesheets is not None:
app = JupyterDash(
external_stylesheets=self.external_stylesheets,
assets_url_path="")
assets_url_path="",
meta_tags=meta_tags)
else:
app = JupyterDash(__name__)
return app
Expand Down
56 changes: 22 additions & 34 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,66 +2,54 @@

setup(
name='explainerdashboard',
version='0.2.3.3',
version='0.2.4',
description='explainerdashboard allows you quickly build an interactive dashboard to explain the inner workings of your machine learning model.',
long_description="""
This package makes it convenient to quickly explain the workings of a
(scikit-learn compatible) fitted machine learning model using either
interactive plots in e.g. Jupyter Notebook or deploying an interactive
dashboard (based on Flask/Dash) that allows you to quickly explore the
impact of different features on model predictions.
This package makes it convenient to quickly deploy a dashboard web app
that explains the workings of a (scikit-learn compatible) fitted machine
learning model. The dashboard provides interactive plots on model performance,
feature importances, feature contributions to individual predictions,
partial dependence plots, SHAP (interaction) values, visualisation of individual
decision trees, etc.
In a lot of organizations, especially governmental, but with the GDPR also
increasingly in private sector, it becomes more and more important to be able
to explain the inner workings of your machine learning algorithms. Customers
have to some extent a right to an explanation why they were selected, and
more and more internal and external regulators require it. With recent
innovations in explainable AI (e.g. SHAP values) the old black box trope is
no longer valid, but it can still take quite a bit of data wrangling and
plot manipulation to get the explanations out of a model. This library aims
to make this easy.
The goal is manyfold:
- Make it easy for data scientists to quickly inspect the workings and
performance of their model in a few lines of code
- Make it easy for data scientists to quickly inspect the inner workings and
performance of their model with just a few lines of code
- Make it possible for non data scientist stakeholders such as managers,
directors, internal and external watchdogs to interactively inspect
the inner workings of the model without having to depend on a data
scientist to generate every plot and table
- Make it easy to build an application that explains individual predictions
of your model for customers that ask for an explanation
- Explain the inner workings of the model to the people working with so
that they gain understanding what the model does and doesn't do.
- Make it easy to build a custom application that explains individual
predictions of your model for customers that ask for an explanation
- Explain the inner workings of the model to the people working with
model in a human-in-the-loop deployment so that they gain understanding
what the model does and doesn't do.
This is important so that they can gain an intuition for when the
model is likely missing information and may have to be overruled.
The library includes:
The dashboard includes:
- Shap values (i.e. what is the contributions of each feature to each
- SHAP values (i.e. what is the contribution of each feature to each
individual prediction?)
- Permutation importances (how much does the model metric deteriorate
when you shuffle a feature?)
- Partial dependence plots (how does the model prediction change when
you vary a single feature?
- Shap interaction values (decompose the shap value into a direct effect
an interaction effects)
- For Random Forests: what is the prediction of each individual decision
tree, and what is the path through each tree? (using dtreeviz)
- For Random Forests and xgboost models: visualization of individual trees
in the ensemble.
- Plus for classifiers: precision plots, confusion matrix, ROC AUC plot,
PR AUC plot, etc
- For regression models: goodness-of-fit plots, residual plots, etc.
The library is designed to be modular so that it should be easy to design your
own interactive dashboards with plotly dash, with most of the work of calculating
and formatting data, and rendering plots and tables handled by explainerdashboard,
so that you can focus on the layout, logic of the interactions, and project specific
textual explanations of the dashboard. (i.e. design it so that it will be interpretable
for business users in your organization, not just data scientists)
Alternatively, there is a built-in standard dashboard with pre-built tabs that
you can select individually.
The library is designed to be modular so that it is easy to design your
own custom dashboards so that you can focus on the layout and project specific
textual explanations of the dashboard. (i.e. design it so that it will be
interpretable for business users in your organization, not just data scientists)
A deployed example can be found at http://titanicexplainer.herokuapp.com
Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/test_dashboards.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,23 @@ def get_multiclass_explainer():

def test_classification_dashboard(dash_duo):
    """Smoke test: the classification dashboard starts, renders its title,
    and produces no browser console errors.

    Fix: removed the stale duplicate `db = ExplainerDashboard(...)` line
    (leftover from the diff) that was immediately overwritten.
    """
    explainer = get_classification_explainer()
    # responsive=False: viewport meta tags interfere with headless-browser runs
    db = ExplainerDashboard(explainer, title="testing", responsive=False)
    dash_duo.start_server(db.app)
    dash_duo.wait_for_text_to_equal("h1", "testing", timeout=30)
    assert dash_duo.get_logs() == [], "browser console should contain no error"


def test_regression_dashboard(dash_duo):
    """Smoke test: the regression dashboard starts, renders its title,
    and produces no browser console errors.

    Fix: removed the stale duplicate `db = ExplainerDashboard(...)` line
    (leftover from the diff) that was immediately overwritten.
    """
    explainer = get_regression_explainer()
    # responsive=False: viewport meta tags interfere with headless-browser runs
    db = ExplainerDashboard(explainer, title="testing", responsive=False)
    dash_duo.start_server(db.app)
    dash_duo.wait_for_text_to_equal("h1", "testing", timeout=30)
    assert dash_duo.get_logs() == [], "browser console should contain no error"


def test_multiclass_dashboard(dash_duo):
    """Smoke test: the multiclass dashboard starts, renders its title,
    and produces no browser console errors.

    Fix: removed the stale duplicate `db = ExplainerDashboard(...)` line
    (leftover from the diff) that was immediately overwritten.
    """
    explainer = get_multiclass_explainer()
    # responsive=False: viewport meta tags interfere with headless-browser runs
    db = ExplainerDashboard(explainer, title="testing", responsive=False)
    dash_duo.start_server(db.app)
    dash_duo.wait_for_text_to_equal("h1", "testing", timeout=30)
    assert dash_duo.get_logs() == [], "browser console should contain no error"
93 changes: 93 additions & 0 deletions tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import unittest
from pathlib import Path
import pickle

import pandas as pd
import numpy as np

from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from xgboost import XGBClassifier, XGBRegressor

from explainerdashboard.explainers import ClassifierExplainer, RegressionExplainer
from explainerdashboard.datasets import titanic_survive, titanic_fare, titanic_names


class TestRFClassifierExplainerPicklable(unittest.TestCase):
    """Check that a ClassifierExplainer wrapping a RandomForestClassifier
    can be pickled to disk (regression test for the unpicklable-explainer bug)."""

    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_survive()
        _, test_names = titanic_names()

        # tiny forest keeps the test fast; depth/size are irrelevant to picklability
        model = RandomForestClassifier(n_estimators=5, max_depth=2)
        model.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
                            model, X_test, y_test,
                            cats=['Sex', 'Cabin', 'Embarked'],
                            labels=['Not survived', 'Survived'],
                            idxs=test_names)

    def test_rf_pickle(self):
        pickle_location = Path.cwd() / "rf_pickle_test.pkl"
        # context manager closes the file handle before we check/unlink it
        with open(pickle_location, "wb") as f:
            pickle.dump(self.explainer, f)
        # exists() must be *called*: the bare bound method is always truthy,
        # so the original `assert pickle_location.exists` could never fail
        assert pickle_location.exists()
        pickle_location.unlink()


class TestXGBClassifierExplainerPicklable(unittest.TestCase):
    """Check that a ClassifierExplainer wrapping an XGBClassifier
    can be pickled to disk (regression test for the unpicklable-explainer bug)."""

    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_survive()
        _, test_names = titanic_names()

        # tiny model keeps the test fast; depth/size are irrelevant to picklability
        model = XGBClassifier(n_estimators=5, max_depth=2)
        model.fit(X_train, y_train)

        self.explainer = ClassifierExplainer(
                            model, X_test, y_test,
                            cats=['Sex', 'Cabin', 'Embarked'],
                            labels=['Not survived', 'Survived'],
                            idxs=test_names)

    def test_xgb_pickle(self):
        pickle_location = Path.cwd() / "xgb_pickle_test.pkl"
        # context manager closes the file handle before we check/unlink it
        with open(pickle_location, "wb") as f:
            pickle.dump(self.explainer, f)
        # exists() must be *called*: the bare bound method is always truthy,
        # so the original `assert pickle_location.exists` could never fail
        assert pickle_location.exists()
        pickle_location.unlink()

class TestRFRegressionExplainerPicklable(unittest.TestCase):
    """Check that a RegressionExplainer wrapping a RandomForestRegressor
    can be pickled to disk (regression test for the unpicklable-explainer bug)."""

    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_fare()
        _, test_names = titanic_names()

        # tiny forest keeps the test fast; depth/size are irrelevant to picklability
        model = RandomForestRegressor(n_estimators=5, max_depth=2)
        model.fit(X_train, y_train)

        self.explainer = RegressionExplainer(
                            model, X_test, y_test,
                            cats=['Sex', 'Cabin', 'Embarked'],
                            idxs=test_names)

    def test_rf_pickle(self):
        pickle_location = Path.cwd() / "rf_reg_pickle_test.pkl"
        # context manager closes the file handle before we check/unlink it
        with open(pickle_location, "wb") as f:
            pickle.dump(self.explainer, f)
        # exists() must be *called*: the bare bound method is always truthy,
        # so the original `assert pickle_location.exists` could never fail
        assert pickle_location.exists()
        pickle_location.unlink()


class TestXGBRegressionExplainerPicklable(unittest.TestCase):
    """Check that a RegressionExplainer wrapping an XGBRegressor
    can be pickled to disk (regression test for the unpicklable-explainer bug)."""

    def setUp(self):
        X_train, y_train, X_test, y_test = titanic_fare()
        _, test_names = titanic_names()

        # tiny model keeps the test fast; depth/size are irrelevant to picklability
        model = XGBRegressor(n_estimators=5, max_depth=2)
        model.fit(X_train, y_train)

        self.explainer = RegressionExplainer(
                            model, X_test, y_test,
                            cats=['Sex', 'Cabin', 'Embarked'],
                            idxs=test_names)

    def test_xgb_pickle(self):
        pickle_location = Path.cwd() / "xgb_reg_pickle_test.pkl"
        # context manager closes the file handle before we check/unlink it
        with open(pickle_location, "wb") as f:
            pickle.dump(self.explainer, f)
        # exists() must be *called*: the bare bound method is always truthy,
        # so the original `assert pickle_location.exists` could never fail
        assert pickle_location.exists()
        pickle_location.unlink()

0 comments on commit 5d375e2

Please sign in to comment.