Skip to content

Commit 3c0eed8

Browse files
author
sgarda
committed
add docs 4 utils, fix error analysis
1 parent 528692e commit 3c0eed8

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

classify_tweets.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,15 @@ def parse_arguments():
125125
df = load_data(dep_file = args.tweets_file, annotations = args.annotations)
126126

127127
# replace column of tokens with preprocessed ones
128-
df['toks'] = df['toks_pos'].apply(preprocessing,rm_url = args.rm_url, red_len = args.red_len,lower = args.lower,
128+
df['proc_toks'] = df['toks_pos'].apply(preprocessing,rm_url = args.rm_url, red_len = args.red_len,lower = args.lower,
129129
rm_sw = args.rm_sw, rm_tags_mentions = args.rm_tagsmen, stem = args.stem)
130130
# still dataframe with all columns
131131

132132
print("Shuffling data")
133133
np.random.seed(42)
134134
df = df.reindex(np.random.permutation(df.index))
135135

136-
tweets = list(df['toks'])
136+
tweets = list(df['proc_toks'])
137137
labels = list(df['label'])
138138
pos = list(df['pos'])
139139
deps = list(df['dep'])

utils.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ def plot_confusion_matrix(cm, classes,
2525
"""
2626
This function prints and plots the confusion matrix.
2727
Normalization can be applied by setting `normalize=True`.
28+
29+
:params:
30+
cm (np.ndarray) : confusion matrix
31+
normalize (bool) : normalize counts by # of class instance
32+
title (string) : plot title
33+
cmap (matplotlib.colors.LinearSegmentedColormap) : color map for image
2834
"""
2935
if normalize:
3036
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
@@ -79,24 +85,37 @@ def base_clf_hp(clf,text):
7985

8086
def by_class_error_analysis(df,y_true,y_pred,limit,error,out_path):
8187
"""
82-
False Positive, False Negative estimation in a one-vs-rest way.
88+
Write to file randomly selected False Positive or False Negative. For multiclass FP,FN are estimated in one-vs-all.
89+
90+
:params:
91+
df (pandas.DataFrame) : data set having `toks` column
92+
y_true (array) : original labels
93+
y_pred (array) : predicted labels
94+
limit (int) : # of FP/FN to be printed
95+
error (str) : type of misclassification. Choices : `FP` (false positive), `FN` (false negative)
96+
out_path (str) : folder where errors will be saved
8397
"""
8498

99+
errors = ['FP','FN']
100+
101+
assert error in errors, "Invalid error choice! Received `{}` : choose from `{}`! ".format(error,errors)
102+
85103
if error == 'FP':
86104
out_file = open(os.path.join(out_path, 'error.FP'),'w+')
87-
else :
105+
elif error == 'FN' :
88106
out_file = open(os.path.join(out_path, 'error.FN'),'w+')
89107

90108
unique_labels = np.unique(y_true)
91109

92110
y_true = np.asarray(y_true)
111+
y_pred = np.asarray(y_pred)
93112

94113
for label in unique_labels:
95114
out_file.write("{}\n".format(str(label).upper()))
96115

97116
if error == 'FP':
98117
error_idx = np.where((y_true!=label) & (y_pred==label))[0] #take indices
99-
else:
118+
elif error == 'FN':
100119
error_idx = np.where((y_true==label) & (y_pred!=label))[0] #take indices
101120

102121
if len(error_idx) < 1:

0 commit comments

Comments
 (0)