From 90f74c8a36d25fe157413665c81501dd864a7c8e Mon Sep 17 00:00:00 2001 From: Davide Date: Tue, 11 Sep 2018 17:01:39 +0200 Subject: [PATCH] Update CSFS_SMBA.py --- src/CSFS_SMBA.py | 45 +-------------------------------------------- 1 file changed, 1 insertion(+), 44 deletions(-) diff --git a/src/CSFS_SMBA.py b/src/CSFS_SMBA.py index 44bfa23..2de5873 100644 --- a/src/CSFS_SMBA.py +++ b/src/CSFS_SMBA.py @@ -34,50 +34,7 @@ def checkFolder(root, path_output): else: raise -def OCC_DecisioneRule(clf_score, cls, clf_name, target): - n_classes = len(cls) - - DTS = {} - for ccn in clf_name: - hits = [] - res = [] - preds = [] - - for i in xrange(0,n_classes): - e_th = np.asarray(clf_score['C'+str(cls[i])]['accuracy'][ccn]) - - e_th[np.where(e_th==-1)] = 0 - - hits.append(e_th) - - ensemble_hits = np.vstack(hits) - - for i in xrange(0, ensemble_hits.shape[1]): # number of sample - hits = ensemble_hits[:,i] - cond = np.sum(hits) - - if cond == 1: #rule 1 - pred = np.where(hits==1)[0] - pred = cls[pred][0] - preds.append(pred) - elif cond == 0: #rule 2 (tie among all OCC) - pred = cls[rnd.randint(0, len(cls) - 1)] - preds.append(pred) - elif cond > 0: - tied_cls = np.where(hits==1)[0] - pred = tied_cls[rnd.randint(0, len(tied_cls) - 1)] - preds.append(pred) - - test_score = accuracy_score(target, preds) - - dic_test_score = { - ccn: test_score - } - - DTS.update(dic_test_score) - - return DTS def classificationDecisionRule(clf_score, cls, clf_name, target): @@ -426,4 +383,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main()