Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions pymatch/Matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def predict_scores(self):
scores += m.predict(self.X[m.params.index])
self.data['scores'] = scores/self.nmodels

def match(self, threshold=0.001, nmatches=1, method='min', max_rand=10):
def match(self, threshold=0.001, nmatches=1, method='min', max_rand=10, with_replacement=True):
"""
Finds suitable match(es) for each record in the minority
dataset, if one exists. Records are excluded from the final
Expand All @@ -165,36 +165,45 @@ def match(self, threshold=0.001, nmatches=1, method='min', max_rand=10):
"min" - choose the profile with the closest score
max_rand : int
max number of profiles to consider when using random tie-breaks
with_replacement : bool
True - matching is performed with replacement, in the
majority group. The same entry from the majority group can be
matched to multiple entries from the minority group
False - matching is performed without replacement, in
the majority group. All matches consist of unique entries.
Matching order is randomized.

Returns
-------
None
"""

if 'scores' not in self.data.columns:
print("Propensity Scores have not been calculated. Using defaults...")
self.fit_scores()
self.predict_scores()
test_scores = self.data[self.data[self.yvar]==True][['scores']]
ctrl_scores = self.data[self.data[self.yvar]==False][['scores']]
result, match_ids = [], []
if with_replacement==False:
test_scores=test_scores.reindex(np.random.permutation(test_scores.index))
for i in range(len(test_scores)):
# uf.progress(i+1, len(test_scores), 'Matching Control to Test...')
match_id = i
score = test_scores.iloc[i]
if method == 'random':
bool_match = abs(ctrl_scores - score) <= threshold
matches = ctrl_scores.loc[bool_match[bool_match.scores].index]
elif method == 'min':
matches = abs(ctrl_scores - score).sort_values('scores').head(nmatches)
else:
raise(AssertionError, "Invalid method parameter, use ('random', 'min')")

bool_match = abs(ctrl_scores - score) <= threshold
matches = ctrl_scores.loc[bool_match[bool_match.scores].index]

if len(matches) == 0:
continue
# randomly choose nmatches indices, if len(matches) > nmatches
select = nmatches if method != 'random' else np.random.choice(range(1, max_rand+1), 1)
chosen = np.random.choice(matches.index, min(select, nmatches), replace=False)
result.extend([test_scores.index[i]] + list(chosen))
match_ids.extend([i] * (len(chosen)+1))
if with_replacement==False:
ctrl_scores['scores'].iloc[list(chosen-len(test_scores))]=999
self.matched_data = self.data.loc[result]
self.matched_data['match_id'] = match_ids
self.matched_data['record_id'] = self.matched_data.index
Expand Down