Skip to content

Commit

Permalink
write results of different datasets into different files to avoid con…
Browse files Browse the repository at this point in the history
…flicts
  • Loading branch information
muhanzhang committed May 10, 2020
1 parent f943f7e commit 50f5041
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
*.pyc
result.txt
extracted_features*
acc_results.txt
auc_results.txt
*results.txt
.*
8 changes: 5 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,13 @@ def loop_dataset(g_list, classifier, sample_idxes, optimizer=None, bsize=cmd_arg

# np.savetxt('test_scores.txt', all_scores) # output test predictions

if not classifier.regression:
if not classifier.regression and cmd_args.printAUC:
all_targets = np.array(all_targets)
fpr, tpr, _ = metrics.roc_curve(all_targets, all_scores, pos_label=1)
auc = metrics.auc(fpr, tpr)
avg_loss = np.concatenate((avg_loss, [auc]))
else:
avg_loss = np.concatenate((avg_loss, [0.0]))

return avg_loss

Expand Down Expand Up @@ -223,11 +225,11 @@ def loop_dataset(g_list, classifier, sample_idxes, optimizer=None, bsize=cmd_arg
test_loss[2] = 0.0
print('\033[93maverage test of epoch %d: loss %.5f acc %.5f auc %.5f\033[0m' % (epoch, test_loss[0], test_loss[1], test_loss[2]))

with open('acc_results.txt', 'a+') as f:
with open(cmd_args.data + '_acc_results.txt', 'a+') as f:
f.write(str(test_loss[1]) + '\n')

if cmd_args.printAUC:
with open('auc_results.txt', 'a+') as f:
with open(cmd_args.data + '_auc_results.txt', 'a+') as f:
f.write(str(test_loss[2]) + '\n')

if cmd_args.extract_features:
Expand Down
6 changes: 3 additions & 3 deletions run_DGCNN.sh
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ if [ ${fold} == 0 ]; then
echo "End of cross-validation"
echo "The total running time is $[stop - start] seconds."
echo "The accuracy results for ${DATA} are as follows:"
cat acc_results.txt
echo "Average accuracy is"
tail -10 acc_results.txt | awk '{ sum += $1; n++ } END { if (n > 0) print sum / n; }'
tail -10 ${DATA}_acc_results.txt
echo "Average accuracy and std are"
tail -10 ${DATA}_acc_results.txt | awk '{ sum += $1; sum2 += $1*$1; n++ } END { if (n > 0) print sum / n; print sqrt(sum2 / n - (sum/n) * (sum/n)); }'
else
CUDA_VISIBLE_DEVICES=${GPU} python main.py \
-seed 1 \
Expand Down

0 comments on commit 50f5041

Please sign in to comment.