@@ -25,6 +25,12 @@ def plot_confusion_matrix(cm, classes,
25
25
"""
26
26
This function prints and plots the confusion matrix.
27
27
Normalization can be applied by setting `normalize=True`.
28
+
29
+ :params:
30
+ cm (np.ndarray) : confusion matrix
31
+ normalize (bool) : normalize counts by # of class instance
32
+ title (string) : plot title
33
+ cmap (matplotlib.colors.LinearSegmentedColormap) : color map for image
28
34
"""
29
35
if normalize :
30
36
cm = cm .astype ('float' ) / cm .sum (axis = 1 )[:, np .newaxis ]
@@ -79,24 +85,37 @@ def base_clf_hp(clf,text):
79
85
80
86
def by_class_error_analysis (df ,y_true ,y_pred ,limit ,error ,out_path ):
81
87
"""
82
- False Positive, False Negative estimation in a one-vs-rest way.
88
+ Write to file randomly selected False Positive or False Negative. For multiclass FP,FN are estimated in one-vs-all.
89
+
90
+ :params:
91
+ df (pandas.DataFrame) : data set having `toks` column
92
+ y_true (array) : original labels
93
+ y_pred (array) : predicted labels
94
+ limit (int) : # of FP/FN to be printed
95
+ error (str) : type of misclassification. Choices : `FP` (false positive), `FN` (false negative)
96
+ out_path (str) : folder where errors will be saved
83
97
"""
84
98
99
+ errors = ['FP' ,'FN' ]
100
+
101
+ assert error in errors , "Invalid error choice! Received `{}` : choose from `{}`! " .format (error ,errors )
102
+
85
103
if error == 'FP' :
86
104
out_file = open (os .path .join (out_path , 'error.FP' ),'w+' )
87
- else :
105
+ elif error == 'FN' :
88
106
out_file = open (os .path .join (out_path , 'error.FN' ),'w+' )
89
107
90
108
unique_labels = np .unique (y_true )
91
109
92
110
y_true = np .asarray (y_true )
111
+ y_pred = np .asarray (y_pred )
93
112
94
113
for label in unique_labels :
95
114
out_file .write ("{}\n " .format (str (label ).upper ()))
96
115
97
116
if error == 'FP' :
98
117
error_idx = np .where ((y_true != label ) & (y_pred == label ))[0 ] #take indices
99
- else :
118
+ elif error == 'FN' :
100
119
error_idx = np .where ((y_true == label ) & (y_pred != label ))[0 ] #take indices
101
120
102
121
if len (error_idx ) < 1 :
0 commit comments