10  | 10  |
11  | 11  | from sklearn.metrics import classification_report, confusion_matrix
12  | 12  |
    | 13  | +from sklearn.utils import shuffle
    | 14  | +
13  | 15  | import numpy as np
14  | 16  | import argparse
15  | 17  | import os
@@ -120,7 +122,7 @@ def parse_arguments():
120 | 122 |     record_group.add_argument("--save", type = str, default = False, help="If set, writes a file with information about the test; otherwise just prints it")
121 | 123 |     record_group.add_argument("--confusion-matrix", action = "store_true", help="Display confusion matrix")
122 | 124 |     record_group.add_argument("--error-analysis", type = str, default = False, help="Save misclassified tweets to a file (path to be provided)")
123 |     | -
    | 125 | +    record_group.add_argument("--stat-test", type = str, default = False, help="Create score files for statistical significance test (paired t-test)")
124 | 126 |
125 | 127 |     return parser.parse_args()
126 | 128 |
@@ -246,151 +248,174 @@ def parse_arguments():
246 | 248 |
247 | 249 | scoring = ["f1_micro", "f1_macro", "precision_micro", "precision_macro", "recall_micro", "recall_macro"]
248 | 250 |
249 |     | -f1_scores = cross_validate(clf, X, labels, cv=10, scoring=scoring, return_train_score=False)
250 |     | -
251 |     | -y_pred = cross_val_predict(clf, X, labels, cv=10)
252 |     | -
253 |     | -report = classification_report(labels, y_pred)
    | 251 | +if args.stat_test:
    | 252 | +
    | 253 | +    with open(args.stat_test, 'w+') as outfile:
254 | 254 |
255 |     | -text = []
256 |     | -text.append("classifier: {}\n".format(args.classifier))
257 |     | -text.append("class weights: {}\n".format(args.class_weights))
258 |     | -
259 |     | -store_hyperparameters(clf, text)
260 |     | -
261 |     | -text.append("\n10 fold cross validation\n")
262 |     | -
263 |     | -text.append("preprocessing\n")
264 |     | -text.append("remove url : {}\n".format(args.rm_url))
265 |     | -text.append("reduce length : {}\n".format(args.red_len))
266 |     | -text.append("lowercase : {}\n".format(args.lower))
267 |     | -text.append("remove stopwords : {}\n".format(args.rm_sw))
268 |     | -text.append("remove tags and mentions : {}\n".format(args.rm_tagsmen))
269 |     | -text.append("stem : {}\n".format(args.stem))
270 |     | -
271 |     | -text.append("features\n")
272 |     | -text.append("ngram_range: {}\n".format(args.ngram_range))
273 |     | -text.append("tfidf: {}\n".format(args.tfidf))
274 |     | -text.append("tsvd : {}\n\n".format(args.tsvd))
275 |     | -text.append("cluster: {}\n".format(args.clusters))
276 |     | -text.append("postags: {}\n".format(args.postags))
277 |     | -text.append("senti net: {}\n".format(args.sentnet))
278 |     | -text.append("senti words: {}\n".format(args.sentiwords))
279 |     | -text.append("subjective score: {}\n".format(args.subjscore))
280 |     | -text.append("pos subjective score: {}\n".format(args.subjscorepos))
281 |     | -text.append("neg subjective score: {}\n".format(args.subjscoreneg))
282 |     | -text.append("bing liu sent words: {}\n".format(args.bingliusent))
283 |     | -text.append("dependency sent words: {}\n".format(args.depsent))
284 |     | -text.append("negated words: {}\n".format(args.negwords))
285 |     | -text.append("scaled features: {}\n".format(args.scale))
286 |     | -text.append("bigram sentiment scores: {}\n".format(args.bigramsent))
287 |     | -text.append("pos bigram sentiment scores: {}\n".format(args.bigramsentpos))
288 |     | -text.append("neg bigram sentiment scores: {}\n".format(args.bigramsentneg))
289 |     | -text.append("unigram sentiment scores: {}\n".format(args.unigramsent))
290 |     | -text.append("pos unigram sentiment scores: {}\n".format(args.unigramsentpos))
291 |     | -text.append("neg unigram sentiment scores: {}\n".format(args.unigramsentneg))
292 |     | -text.append("argument lexicon scores: {}\n".format(args.argscores))
293 |     | -
294 |     | -
295 |     | -text.append("Feature matrix shape: {}\n".format(X.shape))
296 |     | -
297 |     | -text.append("\n")
    | 255 | +        if 'baseline' in args.stat_test:
    | 256 | +
    | 257 | +            print("Set baseline parameters")
    | 258 | +
    | 259 | +            clf.clfs[0].C = 64
    | 260 | +            clf.clfs[0].gamma = 2e-3
    | 261 | +
    | 262 | +            clf.clfs[1].C = 256
    | 263 | +            clf.clfs[1].gamma = 2e-3
    | 264 | +
    | 265 | +            clf.clfs[2].C = 512
    | 266 | +            clf.clfs[2].gamma = 2e-3
298 | 267 |
299 |     | -for score_name, scores in f1_scores.items():
300 |     | -    text.append("average {} : {}\n".format(score_name, sum(scores)/len(scores)))
    | 268 | +        for i in range(10):
    | 269 | +
    | 270 | +            X, labels = shuffle(X, labels, random_state = i)
    | 271 | +
    | 272 | +            f1_scores = cross_validate(clf, X, labels, cv=10, scoring=scoring, return_train_score=False)
    | 273 | +
    | 274 | +            for score_name, scores in f1_scores.items():
    | 275 | +
    | 276 | +                if score_name == 'test_f1_macro':
    | 277 | +
    | 278 | +                    for score in scores:
    | 279 | +
    | 280 | +                        outfile.write("{}\n".format(score))
    | 281 | +
    | 282 | +else:
    | 283 | +
    | 284 | +
301 | 285 |
302 |     | -text.append(report)
    | 286 | +    f1_scores = cross_validate(clf, X, labels, cv=10, scoring=scoring, return_train_score=False)
303 | 287 |
304 |     | -for line in text:
305 |     | -    print(line)
306 |     | -
307 |     | -
308 |     | -# write text to file to keep a record of stuff
309 |     | -if args.save:
310 |     | -    preprocess = "rm"
311 |     | -    if args.rm_url:
312 |     | -        preprocess += "-url"
313 |     | -    if args.rm_sw:
314 |     | -        preprocess += "-sw"
315 |     | -    if args.rm_tagsmen:
316 |     | -        preprocess += "-tm"
317 |     | -    if args.stem:
318 |     | -        preprocess += "-stem"
319 |     | -
320 |     | -    features = ""
321 |     | -    features += "{}gram-".format(args.ngram_range)
322 |     | -    if args.tfidf:
323 |     | -        features = "tfidf-"
324 |     | -    if args.tsvd > 0:
325 |     | -        features += "tsvd-{}-".format(args.tsvd)
326 |     | -    if args.clusters:
327 |     | -        features += "clusters-"
328 |     | -    if args.postags:
329 |     | -        features += "postags-"
330 |     | -    if args.sentnet:
331 |     | -        features += "sentnet-"
332 |     | -    if args.sentiwords:
333 |     | -        features += "sentiwords-"
334 |     | -    if args.subjscore:
335 |     | -        features += "subjscore-"
336 |     | -    if args.subjscorepos:
337 |     | -        features += "subjscorepos-"
338 |     | -    if args.subjscoreneg:
339 |     | -        features += "subjscoreneg-"
340 |     | -    if args.bingliusent:
341 |     | -        features += "bingliu-"
342 |     | -    if args.depsent:
343 |     | -        features += "dep-"
344 |     | -    if args.negwords:
345 |     | -        features += "neg-"
346 |     | -    if args.scale:
347 |     | -        features += "scale-"
348 |     | -    if args.optim_single:
349 |     | -        features += "optim-"
350 |     | -    if args.bigramsent:
351 |     | -        features += "bigramsent-"
352 |     | -    if args.bigramsentpos:
353 |     | -        features += "bigramsentpos-"
354 |     | -    if args.bigramsentneg:
355 |     | -        features += "bigramsentneg-"
356 |     | -    if args.unigramsent:
357 |     | -        features += "unigramsent-"
358 |     | -    if args.unigramsentpos:
359 |     | -        features += "unigramsentpos-"
360 |     | -    if args.unigramsentneg:
361 |     | -        features += "unigramsentneg-"
362 |     | -    if args.argscores:
363 |     | -        features += "argscores-"
364 |     | -
365 |     | -    filename = "{}_{}_{}10cv.txt".format(args.classifier, preprocess, features)
366 |     | -
367 |     | -    if not os.path.exists(args.save):
368 |     | -        os.mkdir(args.save)
    | 288 | +    y_pred = cross_val_predict(clf, X, labels, cv=10)
369 | 289 |
370 |     | -    with open(os.path.join(args.save, filename), "w") as f:
371 |     | -        f.writelines(text)
    | 290 | +    report = classification_report(labels, y_pred)
372 | 291 |
373 |     | -if args.confusion_matrix:
374 |     | -
375 |     | -    cm = confusion_matrix(labels, y_pred)
376 |     | -    np.set_printoptions(precision=2)
377 |     | -    plt.figure()
378 |     | -    plot_confusion_matrix(cm, classes=np.unique(labels),
379 |     | -                          title='Confusion Matrix')
380 |     | -
    | 292 | +    text = []
    | 293 | +    text.append("classifier: {}\n".format(args.classifier))
    | 294 | +    text.append("class weights: {}\n".format(args.class_weights))
381 | 295 |
382 |     | -    plt.savefig('confusion_matrix.png')
    | 296 | +    store_hyperparameters(clf, text)
383 | 297 |
384 |     | -if args.error_analysis:
    | 298 | +    text.append("\n10 fold cross validation\n")
385 | 299 |
386 |     | -    if not os.path.exists(args.error_analysis):
387 |     | -        os.mkdir(args.error_analysis)
    | 300 | +    text.append("preprocessing\n")
    | 301 | +    text.append("remove url : {}\n".format(args.rm_url))
    | 302 | +    text.append("reduce length : {}\n".format(args.red_len))
    | 303 | +    text.append("lowercase : {}\n".format(args.lower))
    | 304 | +    text.append("remove stopwords : {}\n".format(args.rm_sw))
    | 305 | +    text.append("remove tags and mentions : {}\n".format(args.rm_tagsmen))
    | 306 | +    text.append("stem : {}\n".format(args.stem))
388 | 307 |
389 |     | -    by_class_error_analysis(df = df, y_true = labels, y_pred = y_pred, limit = 10, error = 'FP', out_path = args.error_analysis)
390 |     | -    by_class_error_analysis(df = df, y_true = labels, y_pred = y_pred, limit = 10, error = 'FN', out_path = args.error_analysis)
    | 308 | +    text.append("features\n")
    | 309 | +    text.append("ngram_range: {}\n".format(args.ngram_range))
    | 310 | +    text.append("tfidf: {}\n".format(args.tfidf))
    | 311 | +    text.append("tsvd : {}\n\n".format(args.tsvd))
    | 312 | +    text.append("cluster: {}\n".format(args.clusters))
    | 313 | +    text.append("postags: {}\n".format(args.postags))
    | 314 | +    text.append("senti net: {}\n".format(args.sentnet))
    | 315 | +    text.append("senti words: {}\n".format(args.sentiwords))
    | 316 | +    text.append("subjective score: {}\n".format(args.subjscore))
    | 317 | +    text.append("pos subjective score: {}\n".format(args.subjscorepos))
    | 318 | +    text.append("neg subjective score: {}\n".format(args.subjscoreneg))
    | 319 | +    text.append("bing liu sent words: {}\n".format(args.bingliusent))
    | 320 | +    text.append("dependency sent words: {}\n".format(args.depsent))
    | 321 | +    text.append("negated words: {}\n".format(args.negwords))
    | 322 | +    text.append("scaled features: {}\n".format(args.scale))
    | 323 | +    text.append("bigram sentiment scores: {}\n".format(args.bigramsent))
    | 324 | +    text.append("pos bigram sentiment scores: {}\n".format(args.bigramsentpos))
    | 325 | +    text.append("neg bigram sentiment scores: {}\n".format(args.bigramsentneg))
    | 326 | +    text.append("unigram sentiment scores: {}\n".format(args.unigramsent))
    | 327 | +    text.append("pos unigram sentiment scores: {}\n".format(args.unigramsentpos))
    | 328 | +    text.append("neg unigram sentiment scores: {}\n".format(args.unigramsentneg))
    | 329 | +    text.append("argument lexicon scores: {}\n".format(args.argscores))
    | 330 | +
391 | 331 |
    | 332 | +    text.append("Feature matrix shape: {}\n".format(X.shape))
392 | 333 |
    | 334 | +    text.append("\n")
393 | 335 |
    | 336 | +    for score_name, scores in f1_scores.items():
    | 337 | +        text.append("average {} : {}\n".format(score_name, sum(scores)/len(scores)))
    | 338 | +
    | 339 | +    text.append(report)
    | 340 | +
    | 341 | +    for line in text:
    | 342 | +        print(line)
    | 343 | +
    | 344 | +
    | 345 | +    # write text to file to keep a record of stuff
    | 346 | +    if args.save:
    | 347 | +        preprocess = "rm"
    | 348 | +        if args.rm_url:
    | 349 | +            preprocess += "-url"
    | 350 | +        if args.rm_sw:
    | 351 | +            preprocess += "-sw"
    | 352 | +        if args.rm_tagsmen:
    | 353 | +            preprocess += "-tm"
    | 354 | +        if args.stem:
    | 355 | +            preprocess += "-stem"
    | 356 | +
    | 357 | +        features = ""
    | 358 | +        features += "{}gram-".format(args.ngram_range)
    | 359 | +        if args.tfidf:
    | 360 | +            features = "tfidf-"
    | 361 | +        if args.tsvd > 0:
    | 362 | +            features += "tsvd-{}-".format(args.tsvd)
    | 363 | +        if args.clusters:
    | 364 | +            features += "clusters-"
    | 365 | +        if args.postags:
    | 366 | +            features += "postags-"
    | 367 | +        if args.sentnet:
    | 368 | +            features += "sentnet-"
    | 369 | +        if args.sentiwords:
    | 370 | +            features += "sentiwords-"
    | 371 | +        if args.subjscore:
    | 372 | +            features += "subjscore-"
    | 373 | +        if args.bingliusent:
    | 374 | +            features += "bingliu-"
    | 375 | +        if args.depsent:
    | 376 | +            features += "dep-"
    | 377 | +        if args.negwords:
    | 378 | +            features += "neg-"
    | 379 | +        if args.scale:
    | 380 | +            features += "scale-"
    | 381 | +        if args.optim_single:
    | 382 | +            features += "optim-"
    | 383 | +        if args.bigramsent:
    | 384 | +            features += "bigramsent-"
    | 385 | +        if args.unigramsent:
    | 386 | +            features += "unigramsent-"
    | 387 | +        if args.argscores:
    | 388 | +            features += "argscores-"
    | 389 | +
    | 390 | +        filename = "{}_{}_{}10cv.txt".format(args.classifier, preprocess, features)
    | 391 | +
    | 392 | +        if not os.path.exists(args.save):
    | 393 | +            os.mkdir(args.save)
    | 394 | +
    | 395 | +        with open(os.path.join(args.save, filename), "w") as f:
    | 396 | +            f.writelines(text)
    | 397 | +
    | 398 | +    if args.confusion_matrix:
    | 399 | +
    | 400 | +        cm = confusion_matrix(labels, y_pred)
    | 401 | +        np.set_printoptions(precision=2)
    | 402 | +        plt.figure()
    | 403 | +        plot_confusion_matrix(cm, classes=np.unique(labels),
    | 404 | +                              title='Confusion Matrix')
    | 405 | +
    | 406 | +
    | 407 | +        plt.savefig('confusion_matrix.png')
    | 408 | +
    | 409 | +    if args.error_analysis:
    | 410 | +
    | 411 | +        if not os.path.exists(args.error_analysis):
    | 412 | +            os.mkdir(args.error_analysis)
    | 413 | +
    | 414 | +        by_class_error_analysis(df = df, y_true = labels, y_pred = y_pred, limit = 10, error = 'FP', out_path = args.error_analysis)
    | 415 | +        by_class_error_analysis(df = df, y_true = labels, y_pred = y_pred, limit = 10, error = 'FN', out_path = args.error_analysis)
    | 416 | +
    | 417 | +
    | 418 | +
394 | 419 |
395 | 420 |
396 | 421 |
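
The --stat-test path writes one test_f1_macro value per line, so each run produces 100 scores (10 shuffles x 10 folds). Because every run shuffles with the same random_state sequence (0 through 9) and the fold splits are deterministic, scores from two runs line up fold-for-fold. A minimal sketch of the downstream paired t-test the help text refers to, assuming scipy is available and two such score files exist (the file names below are hypothetical):

import numpy as np
from scipy.stats import ttest_rel

# Each file holds one macro-F1 score per line, as written by --stat-test.
baseline = np.loadtxt("baseline_scores.txt")  # hypothetical path
system = np.loadtxt("system_scores.txt")      # hypothetical path

# Paired t-test: the scores form matched pairs because both runs
# used identical shuffles (random_state = 0..9) and fold splits.
t_stat, p_value = ttest_rel(system, baseline)
print("mean difference: {:.4f}".format((system - baseline).mean()))
print("t = {:.3f}, p = {:.4f}".format(t_stat, p_value))

One caveat on this design: repeated cross-validation reuses the same data across folds, so the 100 scores are not independent samples, and the plain paired t-test named in the help string tends to be optimistic compared to variance-corrected alternatives.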