diff --git a/pymatch/Matcher.py b/pymatch/Matcher.py index c7020bf..0dbb571 100644 --- a/pymatch/Matcher.py +++ b/pymatch/Matcher.py @@ -141,7 +141,7 @@ def predict_scores(self): scores += m.predict(self.X[m.params.index]) self.data['scores'] = scores/self.nmodels - def match(self, threshold=0.001, nmatches=1, method='min', max_rand=10): + def match(self, threshold=0.001, nmatches=1, method='min', max_rand=10, with_replacement=True): """ Finds suitable match(es) for each record in the minority dataset, if one exists. Records are exlcuded from the final @@ -165,11 +165,19 @@ def match(self, threshold=0.001, nmatches=1, method='min', max_rand=10): "min" - choose the profile with the closest score max_rand : int max number of profiles to consider when using random tie-breaks + with_replacement : bool + True - matching is performed with replacement, in the + majority group. The same entry from the majority group can be + matched to multiple entries from the minority group + False - matching is performed without replacement, in + the majority group. All matches consist of unique entries. + Matching order is randomized. Returns ------- None """ + if 'scores' not in self.data.columns: print("Propensity Scores have not been calculated. Using defaults...") self.fit_scores() @@ -177,17 +185,16 @@ def match(self, threshold=0.001, nmatches=1, method='min', max_rand=10): test_scores = self.data[self.data[self.yvar]==True][['scores']] ctrl_scores = self.data[self.data[self.yvar]==False][['scores']] result, match_ids = [], [] + if with_replacement==False: + test_scores=test_scores.reindex(np.random.permutation(test_scores.index)) for i in range(len(test_scores)): # uf.progress(i+1, len(test_scores), 'Matching Control to Test...') match_id = i score = test_scores.iloc[i] - if method == 'random': - bool_match = abs(ctrl_scores - score) <= threshold - matches = ctrl_scores.loc[bool_match[bool_match.scores].index] - elif method == 'min': - matches = abs(ctrl_scores - score).sort_values('scores').head(nmatches) - else: - raise(AssertionError, "Invalid method parameter, use ('random', 'min')") + + bool_match = abs(ctrl_scores - score) <= threshold + matches = ctrl_scores.loc[bool_match[bool_match.scores].index] + if len(matches) == 0: continue # randomly choose nmatches indices, if len(matches) > nmatches @@ -195,6 +202,8 @@ def match(self, threshold=0.001, nmatches=1, method='min', max_rand=10): chosen = np.random.choice(matches.index, min(select, nmatches), replace=False) result.extend([test_scores.index[i]] + list(chosen)) match_ids.extend([i] * (len(chosen)+1)) + if with_replacement==False: + ctrl_scores['scores'].iloc[list(chosen-len(test_scores))]=999 self.matched_data = self.data.loc[result] self.matched_data['match_id'] = match_ids self.matched_data['record_id'] = self.matched_data.index