diff --git a/pymatch/Matcher.py b/pymatch/Matcher.py index bf4537f..9fb87ad 100644 --- a/pymatch/Matcher.py +++ b/pymatch/Matcher.py @@ -259,7 +259,7 @@ def prop_test(self, col): else: print("{} is a continuous variable".format(col)) - def compare_continuous(self, save=False, return_table=False): + def compare_continuous(self, save=False, return_table=False, columns=None): """ Plots the ECDFs for continuous features before and after matching. Each chart title contains test results @@ -289,6 +289,12 @@ def compare_continuous(self, save=False, return_table=False): ---------- return_table : bool Should the function a table with tests and statistics? + columns : List[str] + If None, the method plots/returns the comparison for all + inferred continuous columns. If passed in a list of + column names, plots/returns the comparison for the + provided list without verifying if continuous. We + will ignore any column in the `exclude` list passed in. Returns ------- @@ -298,57 +304,64 @@ def compare_continuous(self, save=False, return_table=False): """ test_results = [] - for col in self.matched_data.columns: - if uf.is_continuous(col, self.X) and col not in self.exclude: - # organize data - trb, cob = self.test[col], self.control[col] - tra = self.matched_data[self.matched_data[self.yvar]==True][col] - coa = self.matched_data[self.matched_data[self.yvar]==False][col] - xtb, xcb = ECDF(trb), ECDF(cob) - xta, xca = ECDF(tra),ECDF(coa) - - # before/after stats - std_diff_med_before, std_diff_mean_before = uf.std_diff(trb, cob) - std_diff_med_after, std_diff_mean_after = uf.std_diff(tra, coa) - pb, truthb = uf.grouped_permutation_test(uf.chi2_distance, trb, cob) - pa, trutha = uf.grouped_permutation_test(uf.chi2_distance, tra, coa) - ksb = round(uf.ks_boot(trb, cob, nboots=1000), 6) - ksa = round(uf.ks_boot(tra, coa, nboots=1000), 6) - - # plotting - f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(12, 5)) - ax1.plot(xcb.x, xcb.y, label='Control', color=self.control_color) - ax1.plot(xtb.x, xtb.y, label='Test', color=self.test_color) - ax1.plot(xcb.x, xcb.y, label='Control', color=self.control_color) - ax1.plot(xtb.x, xtb.y, label='Test', color=self.test_color) - - title_str = ''' - ECDF for {} {} Matching - KS p-value: {} - Grouped Perm p-value: {} - Std. Median Difference: {} - Std. Mean Difference: {} - ''' - ax1.set_title(title_str.format(col, "before", ksb, pb, - std_diff_med_before, std_diff_mean_before)) - ax2.plot(xca.x, xca.y, label='Control') - ax2.plot(xta.x, xta.y, label='Test') - ax2.set_title(title_str.format(col, "after", ksa, pa, - std_diff_med_after, std_diff_mean_after)) - ax2.legend(loc="lower right") - plt.xlim((0, np.percentile(xta.x, 99))) - - test_results.append({ - "var": col, - "ks_before": ksb, - "ks_after": ksa, - "grouped_chisqr_before": pb, - "grouped_chisqr_after": pa, - "std_median_diff_before": std_diff_med_before, - "std_median_diff_after": std_diff_med_after, - "std_mean_diff_before": std_diff_mean_before, - "std_mean_diff_after": std_diff_mean_after - }) + if columns is None: + columns_to_plot = [ + col + for col in self.matched_data.columns + if uf.is_continuous(col, self.X) and col not in self.exclude + ] + else: + columns_to_plot = [col for col in columns if col not in self.exclude] + for col in columns_to_plot: + # organize data + trb, cob = self.test[col], self.control[col] + tra = self.matched_data[self.matched_data[self.yvar]==True][col] + coa = self.matched_data[self.matched_data[self.yvar]==False][col] + xtb, xcb = ECDF(trb), ECDF(cob) + xta, xca = ECDF(tra),ECDF(coa) + + # before/after stats + std_diff_med_before, std_diff_mean_before = uf.std_diff(trb, cob) + std_diff_med_after, std_diff_mean_after = uf.std_diff(tra, coa) + pb, truthb = uf.grouped_permutation_test(uf.chi2_distance, trb, cob) + pa, trutha = uf.grouped_permutation_test(uf.chi2_distance, tra, coa) + ksb = round(uf.ks_boot(trb, cob, nboots=1000), 6) + ksa = round(uf.ks_boot(tra, coa, nboots=1000), 6) + + # plotting + f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(12, 5)) + ax1.plot(xcb.x, xcb.y, label='Control', color=self.control_color) + ax1.plot(xtb.x, xtb.y, label='Test', color=self.test_color) + ax1.plot(xcb.x, xcb.y, label='Control', color=self.control_color) + ax1.plot(xtb.x, xtb.y, label='Test', color=self.test_color) + + title_str = ''' + ECDF for {} {} Matching + KS p-value: {} + Grouped Perm p-value: {} + Std. Median Difference: {} + Std. Mean Difference: {} + ''' + ax1.set_title(title_str.format(col, "before", ksb, pb, + std_diff_med_before, std_diff_mean_before)) + ax2.plot(xca.x, xca.y, label='Control') + ax2.plot(xta.x, xta.y, label='Test') + ax2.set_title(title_str.format(col, "after", ksa, pa, + std_diff_med_after, std_diff_mean_after)) + ax2.legend(loc="lower right") + plt.xlim((0, np.percentile(xta.x, 99))) + + test_results.append({ + "var": col, + "ks_before": ksb, + "ks_after": ksa, + "grouped_chisqr_before": pb, + "grouped_chisqr_after": pa, + "std_median_diff_before": std_diff_med_before, + "std_median_diff_after": std_diff_med_after, + "std_mean_diff_before": std_diff_mean_before, + "std_mean_diff_after": std_diff_mean_after + }) var_order = [ "var", @@ -364,7 +377,7 @@ def compare_continuous(self, save=False, return_table=False): return pd.DataFrame(test_results)[var_order] if return_table else None - def compare_categorical(self, return_table=False): + def compare_categorical(self, return_table=False, columns=None): """ Plots the proportional differences of each enumerated discete column for test and control. @@ -379,6 +392,12 @@ def compare_categorical(self, return_table=False): return_table : bool Should the function return a table with test results? + columns : List[str] + If None, the method plots/returns the comparison for all + inferred categorical columns. If passed in a list of + column names, plots/returns the comparison for the + provided list without verifying if categorical. We + will ignore any column in the `exclude` list passed in. Return ------ @@ -404,20 +423,27 @@ def prep_plot(data, var, colname): {} | {} ''' test_results = [] - for col in self.matched_data.columns: - if not uf.is_continuous(col, self.X) and col not in self.exclude: - dbefore = prep_plot(self.data, col, colname="before") - dafter = prep_plot(self.matched_data, col, colname="after") - df = dbefore.join(dafter) - test_results_i = self.prop_test(col) - test_results.append(test_results_i) - - # plotting - df.plot.bar(alpha=.8) - plt.title(title_str.format(col, test_results_i["before"], - test_results_i["after"])) - lim = max(.09, abs(df).max().max()) + .01 - plt.ylim((-lim, lim)) + if columns is None: + columns_to_plot = [ + col + for col in self.matched_data.columns + if not uf.is_continuous(col, self.X) and col not in self.exclude + ] + else: + columns_to_plot = [col for col in columns if col not in self.exclude] + for col in columns_to_plot: + dbefore = prep_plot(self.data, col, colname="before") + dafter = prep_plot(self.matched_data, col, colname="after") + df = dbefore.join(dafter) + test_results_i = self.prop_test(col) + test_results.append(test_results_i) + + # plotting + df.plot.bar(alpha=.8) + plt.title(title_str.format(col, test_results_i["before"], + test_results_i["after"])) + lim = max(.09, abs(df).max().max()) + .01 + plt.ylim((-lim, lim)) return pd.DataFrame(test_results)[['var', 'before', 'after']] if return_table else None def prep_prop_test(self, data, var):