Skip to content

Commit

Permalink
Update CSFS_SMBA.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DavideNardone authored Sep 11, 2018
1 parent 6c668bb commit 90f74c8
Showing 1 changed file with 1 addition and 44 deletions.
45 changes: 1 addition & 44 deletions src/CSFS_SMBA.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,50 +34,7 @@ def checkFolder(root, path_output):
else:
raise

def OCC_DecisioneRule(clf_score, cls, clf_name, target):

n_classes = len(cls)

DTS = {}
for ccn in clf_name:
hits = []
res = []
preds = []

for i in xrange(0,n_classes):
e_th = np.asarray(clf_score['C'+str(cls[i])]['accuracy'][ccn])

e_th[np.where(e_th==-1)] = 0

hits.append(e_th)

ensemble_hits = np.vstack(hits)

for i in xrange(0, ensemble_hits.shape[1]): # number of sample
hits = ensemble_hits[:,i]
cond = np.sum(hits)

if cond == 1: #rule 1
pred = np.where(hits==1)[0]
pred = cls[pred][0]
preds.append(pred)
elif cond == 0: #rule 2 (tie among all OCC)
pred = cls[rnd.randint(0, len(cls) - 1)]
preds.append(pred)
elif cond > 0:
tied_cls = np.where(hits==1)[0]
pred = tied_cls[rnd.randint(0, len(tied_cls) - 1)]
preds.append(pred)

test_score = accuracy_score(target, preds)

dic_test_score = {
ccn: test_score
}

DTS.update(dic_test_score)

return DTS

def classificationDecisionRule(clf_score, cls, clf_name, target):

Expand Down Expand Up @@ -426,4 +383,4 @@ def main():


if __name__ == '__main__':
main()
main()

0 comments on commit 90f74c8

Please sign in to comment.