Commit d70ae9e

Author: sgarda
Commit message: t-test opt + results
1 parent f0308e1 commit d70ae9e

File tree

5 files changed, +491 -133 lines changed


classify_tweets.py

+158-133
@@ -10,6 +10,8 @@

 from sklearn.metrics import classification_report, confusion_matrix

+from sklearn.utils import shuffle
+
 import numpy as np
 import argparse
 import os
@@ -120,7 +122,7 @@ def parse_arguments():
     record_group.add_argument("--save", type = str, default = False, help="If true it writes a file with information about the test, else it just prints it")
     record_group.add_argument("--confusion-matrix",action = "store_true", help="Display confusion matrix")
     record_group.add_argument("--error-analysis", type = str, default = False, help="Save to a file (path to be provided) tweets misclassified")
-
+    record_group.add_argument("--stat-test", type = str, default = False, help="Create score files for statistical significance test (paired t-test)")

     return parser.parse_args()

@@ -246,151 +248,174 @@ def parse_arguments():

 scoring = ["f1_micro", "f1_macro", "precision_micro", "precision_macro", "recall_micro", "recall_macro"]

-f1_scores = cross_validate(clf, X, labels, cv=10, scoring=scoring, return_train_score=False)
-
-y_pred = cross_val_predict(clf, X, labels, cv=10)
-
-report = classification_report(labels, y_pred)
+if args.stat_test:
+
+    with open(args.stat_test, 'w+') as outfile:

-text = []
-text.append("classifier: {}\n".format(args.classifier))
-text.append("class weights: {}\n".format(args.class_weights))
-
-store_hyperparameters(clf,text)
-
-text.append("\n10 fold cross validation\n")
-
-text.append("preprocessing\n")
-text.append("remove url : {}\n".format(args.rm_url))
-text.append("reduce length : {}\n".format(args.red_len))
-text.append("lowercase : {}\n".format(args.lower))
-text.append("remove stopwords : {}\n".format(args.rm_sw))
-text.append("remove tags and mentions : {}\n".format(args.rm_tagsmen))
-text.append("stem : {}\n".format(args.stem))
-
-text.append("features\n")
-text.append("ngram_range: {}\n".format(args.ngram_range))
-text.append("tfidf: {}\n".format(args.tfidf))
-text.append("tsvd : {}\n\n".format(args.tsvd))
-text.append("cluster: {}\n".format(args.clusters))
-text.append("postags: {}\n".format(args.postags))
-text.append("senti net: {}\n".format(args.sentnet))
-text.append("senti words: {}\n".format(args.sentiwords))
-text.append("subjective score: {}\n".format(args.subjscore))
-text.append("pos subjective score: {}\n".format(args.subjscorepos))
-text.append("neg subjective score: {}\n".format(args.subjscoreneg))
-text.append("bing liu sent words: {}\n".format(args.bingliusent))
-text.append("dependency sent words: {}\n".format(args.depsent))
-text.append("negated words: {}\n".format(args.negwords))
-text.append("scaled features: {}\n".format(args.scale))
-text.append("bigram sentiment scores: {}\n".format(args.bigramsent))
-text.append("pos bigram sentiment scores: {}\n".format(args.bigramsentpos))
-text.append("neg bigram sentiment scores: {}\n".format(args.bigramsentneg))
-text.append("unigram sentiment scores: {}\n".format(args.unigramsent))
-text.append("pos unigram sentiment scores: {}\n".format(args.unigramsentpos))
-text.append("neg unigram sentiment scores: {}\n".format(args.unigramsentneg))
-text.append("argument lexicon scores: {}\n".format(args.argscores))
-
-
-text.append("Feature matrix shape: {}\n".format(X.shape))
-
-text.append("\n")
+        if 'baseline' in args.stat_test:
+
+            print("Set basline parameters")
+
+            clf.clfs[0].C = 64
+            clf.clfs[0].gamma = 2e-3
+
+            clf.clfs[1].C = 256
+            clf.clfs[1].gamma = 2e-3
+
+            clf.clfs[2].C = 512
+            clf.clfs[2].gamma = 2e-3

-for score_name, scores in f1_scores.items():
-    text.append("average {} : {}\n".format(score_name,sum(scores)/len(scores)))
+        for i in range(10):
+
+            X,labels = shuffle(X,labels, random_state = i)
+
+            f1_scores = cross_validate(clf, X, labels, cv=10, scoring=scoring, return_train_score=False)
+
+            for score_name,scores in f1_scores.items():
+
+                if score_name == 'test_f1_macro':
+
+                    for score in scores:
+
+                        outfile.write("{}\n".format(score))
+
+else:
+
+

-text.append(report)
+    f1_scores = cross_validate(clf, X, labels, cv=10, scoring=scoring, return_train_score=False)

-for line in text:
-    print(line)
-
-
-# write text to file to keep a record of stuff
-if args.save:
-    preprocess = "rm"
-    if args.rm_url:
-        preprocess += "-url"
-    if args.rm_sw:
-        preprocess += "-sw"
-    if args.rm_tagsmen:
-        preprocess += "-tm"
-    if args.stem:
-        preprocess += "-stem"
-
-    features = ""
-    features += "{}gram-".format(args.ngram_range)
-    if args.tfidf:
-        features = "tfidf-"
-    if args.tsvd > 0:
-        features += "tsvd-{}-".format(args.tsvd)
-    if args.clusters:
-        features += "clusters-"
-    if args.postags:
-        features += "postags-"
-    if args.sentnet:
-        features += "sentnet-"
-    if args.sentiwords:
-        features += "sentiwords-"
-    if args.subjscore:
-        features += "subjscore-"
-    if args.subjscorepos:
-        features += "subjscorepos-"
-    if args.subjscoreneg:
-        features += "subjscoreneg-"
-    if args.bingliusent:
-        features += "bingliu-"
-    if args.depsent:
-        features += "dep-"
-    if args.negwords:
-        features += "neg-"
-    if args.scale:
-        features += "scale-"
-    if args.optim_single:
-        features += "optim-"
-    if args.bigramsent:
-        features += "bigramsent-"
-    if args.bigramsentpos:
-        features += "bigramsentpos-"
-    if args.bigramsentneg:
-        features += "bigramsentneg-"
-    if args.unigramsent:
-        features += "unigramsent-"
-    if args.unigramsentpos:
-        features += "unigramsentpos-"
-    if args.unigramsentneg:
-        features += "unigramsentneg-"
-    if args.argscores:
-        features += "argscores-"
-
-    filename = "{}_{}_{}10cv.txt".format(args.classifier,preprocess,features)
-
-    if not os.path.exists(args.save):
-        os.mkdir(args.save)
+    y_pred = cross_val_predict(clf, X, labels, cv=10)

-    with open(os.path.join(args.save,filename), "w") as f:
-        f.writelines(text)
+    report = classification_report(labels, y_pred)

-if args.confusion_matrix:
-
-    cm = confusion_matrix(labels,y_pred)
-    np.set_printoptions(precision=2)
-    plt.figure()
-    plot_confusion_matrix(cm, classes=np.unique(labels),
-                          title='Confusion Matrix')
-
+    text = []
+    text.append("classifier: {}\n".format(args.classifier))
+    text.append("class weights: {}\n".format(args.class_weights))

-    plt.savefig('confustion_matrix.png')
+    store_hyperparameters(clf,text)

-if args.error_analysis:
+    text.append("\n10 fold cross validation\n")

-    if not os.path.exists(args.error_analysis):
-        os.mkdir(args.error_analysis)
+    text.append("preprocessing\n")
+    text.append("remove url : {}\n".format(args.rm_url))
+    text.append("reduce length : {}\n".format(args.red_len))
+    text.append("lowercase : {}\n".format(args.lower))
+    text.append("remove stopwords : {}\n".format(args.rm_sw))
+    text.append("remove tags and mentions : {}\n".format(args.rm_tagsmen))
+    text.append("stem : {}\n".format(args.stem))

-    by_class_error_analysis(df = df, y_true = labels, y_pred = y_pred, limit = 10, error = 'FP', out_path = args.error_analysis )
-    by_class_error_analysis(df = df, y_true = labels, y_pred = y_pred, limit = 10, error = 'FN', out_path = args.error_analysis )
+    text.append("features\n")
+    text.append("ngram_range: {}\n".format(args.ngram_range))
+    text.append("tfidf: {}\n".format(args.tfidf))
+    text.append("tsvd : {}\n\n".format(args.tsvd))
+    text.append("cluster: {}\n".format(args.clusters))
+    text.append("postags: {}\n".format(args.postags))
+    text.append("senti net: {}\n".format(args.sentnet))
+    text.append("senti words: {}\n".format(args.sentiwords))
+    text.append("subjective score: {}\n".format(args.subjscore))
+    text.append("pos subjective score: {}\n".format(args.subjscorepos))
+    text.append("neg subjective score: {}\n".format(args.subjscoreneg))
+    text.append("bing liu sent words: {}\n".format(args.bingliusent))
+    text.append("dependency sent words: {}\n".format(args.depsent))
+    text.append("negated words: {}\n".format(args.negwords))
+    text.append("scaled features: {}\n".format(args.scale))
+    text.append("bigram sentiment scores: {}\n".format(args.bigramsent))
+    text.append("pos bigram sentiment scores: {}\n".format(args.bigramsentpos))
+    text.append("neg bigram sentiment scores: {}\n".format(args.bigramsentneg))
+    text.append("unigram sentiment scores: {}\n".format(args.unigramsent))
+    text.append("pos unigram sentiment scores: {}\n".format(args.unigramsentpos))
+    text.append("neg unigram sentiment scores: {}\n".format(args.unigramsentneg))
+    text.append("argument lexicon scores: {}\n".format(args.argscores))
+

+    text.append("Feature matrix shape: {}\n".format(X.shape))

+    text.append("\n")

+    for score_name, scores in f1_scores.items():
+        text.append("average {} : {}\n".format(score_name,sum(scores)/len(scores)))
+
+    text.append(report)
+
+    for line in text:
+        print(line)
+
+
+    # write text to file to keep a record of stuff
+    if args.save:
+        preprocess = "rm"
+        if args.rm_url:
+            preprocess += "-url"
+        if args.rm_sw:
+            preprocess += "-sw"
+        if args.rm_tagsmen:
+            preprocess += "-tm"
+        if args.stem:
+            preprocess += "-stem"
+
+        features = ""
+        features += "{}gram-".format(args.ngram_range)
+        if args.tfidf:
+            features = "tfidf-"
+        if args.tsvd > 0:
+            features += "tsvd-{}-".format(args.tsvd)
+        if args.clusters:
+            features += "clusters-"
+        if args.postags:
+            features += "postags-"
+        if args.sentnet:
+            features += "sentnet-"
+        if args.sentiwords:
+            features += "sentiwords-"
+        if args.subjscore:
+            features += "subjscore-"
+        if args.bingliusent:
+            features += "bingliu-"
+        if args.depsent:
+            features += "dep-"
+        if args.negwords:
+            features += "neg-"
+        if args.scale:
+            features += "scale-"
+        if args.optim_single:
+            features += "optim-"
+        if args.bigramsent:
+            features += "bigramsent-"
+        if args.unigramsent:
+            features += "unigramsent-"
+        if args.argscores:
+            features += "argscores-"
+
+        filename = "{}_{}_{}10cv.txt".format(args.classifier,preprocess,features)
+
+        if not os.path.exists(args.save):
+            os.mkdir(args.save)
+
+        with open(os.path.join(args.save,filename), "w") as f:
+            f.writelines(text)
+
+    if args.confusion_matrix:
+
+        cm = confusion_matrix(labels,y_pred)
+        np.set_printoptions(precision=2)
+        plt.figure()
+        plot_confusion_matrix(cm, classes=np.unique(labels),
+                              title='Confusion Matrix')
+
+
+        plt.savefig('confustion_matrix.png')
+
+    if args.error_analysis:
+
+        if not os.path.exists(args.error_analysis):
+            os.mkdir(args.error_analysis)
+
+        by_class_error_analysis(df = df, y_true = labels, y_pred = y_pred, limit = 10, error = 'FP', out_path = args.error_analysis )
+        by_class_error_analysis(df = df, y_true = labels, y_pred = y_pred, limit = 10, error = 'FN', out_path = args.error_analysis )
+
+
+


