From 309d9ca37774d0be5bb527edc7f2694e6c6fb8bf Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Mon, 7 Jul 2014 19:06:47 +0900 Subject: [PATCH] Added grading scripts --- script/grade-dep.py | 23 +++++ script/grade-prediction.py | 37 ++++++++ script/gradekkc.pl | 172 +++++++++++++++++++++++++++++++++++++ script/gradepos.pl | 48 +++++++++++ script/gradews.pl | 74 ++++++++++++++++ script/print-trees.py | 12 +++ 6 files changed, 366 insertions(+) create mode 100755 script/grade-dep.py create mode 100755 script/grade-prediction.py create mode 100755 script/gradekkc.pl create mode 100755 script/gradepos.pl create mode 100755 script/gradews.pl create mode 100755 script/print-trees.py diff --git a/script/grade-dep.py b/script/grade-dep.py new file mode 100755 index 0000000..f1967d8 --- /dev/null +++ b/script/grade-dep.py @@ -0,0 +1,23 @@ +#!/usr/bin/python + +import sys + +ref_file = open(sys.argv[1], "r") +refs = ref_file.readlines() +tst_file = open(sys.argv[2], "r") +tsts = tst_file.readlines() + +total = 0 +correct = 0 +for i, ref_line in enumerate(refs): + tst_line = tsts[i] + ref_line = ref_line.strip() + tst_line = tst_line.strip() + if len(ref_line) > 0: + (rnum, rname, rname1, rpos, rpos1, runder, rhead, rdep) = ref_line.split('\t') + (tnum, tname, tname1, tpos, tpos1, tunder, thead, tdep) = tst_line.split('\t') + total += 1 + if rhead == thead: + correct += 1 + +print "%f%% (%d/%d)" % (float(correct)/total*100, correct, total) diff --git a/script/grade-prediction.py b/script/grade-prediction.py new file mode 100755 index 0000000..8fd7c8e --- /dev/null +++ b/script/grade-prediction.py @@ -0,0 +1,37 @@ +#!/usr/bin/python + +import sys +import math + +ref = list() +test = list() + +# Load the reference file +ref_file = open(sys.argv[1], "r") +for line in ref_file: + line = line.strip() + columns = line.split("\t") + ref.append(int(columns[0])) +ref_file.close() + +# Load the testing file +test_file = open(sys.argv[2], "r") +for line in test_file: + line = line.strip() + columns = line.split("\t") + test.append(int(columns[0])) +test_file.close() + +# Check to make sure that both are the same length +if len(test) != len(ref): + print "Lengths of test (%i) and reference (%i) file don't match" % (len(test), len(ref)) + sys.exit(1) + +total = 0 +correct = 0 +for i in range(0, len(ref)): + total += 1 + if ref[i] == test[i]: + correct += 1 + +print "Accuracy = %f%%" % float(float(correct)/float(total)*100.0) diff --git a/script/gradekkc.pl b/script/gradekkc.pl new file mode 100755 index 0000000..aa66c84 --- /dev/null +++ b/script/gradekkc.pl @@ -0,0 +1,172 @@ +#!/usr/bin/perl + +# This is a script to grade word error rates according to edit distance + +binmode STDIN, ":utf8"; +binmode STDOUT, ":utf8"; +binmode STDERR, ":utf8"; + +use utf8; +use strict; +use List::Util qw(max min); +use Cwd qw(cwd); + +my $PRINT_INLINE = 1; + +sub width { + $_ = shift; + my $ret = 0; + for(split(//)) { + $ret += ((/\p{InKatakana}/ or /\p{InHiragana}/ or /\p{InCJKSymbolsAndPunctuation}/ or /\p{InKatakanaPhoneticExtensions}/ or /\p{InCJKUnifiedIdeographs}/)?2:1); + } + return $ret; +} + +sub pad { + my ($s, $l) = @_; + return $s . (' ' x ($l-width($s))); +} + +# return a minimum edit distance path with the following notation, plus cost +# d=delete i=insert s=substitute e=equal +sub levenshtein { + my ($s, $t) = @_; + my (@sc, $m, @tc, $n, %d, %str, $i, $j, $id, $cost, $type, $aid, $bid, $cid, $c); + @sc = split(/ +/, $s); + $m = @sc; + @tc = split(/ +/, $t); + $n = @tc; + # initialize + foreach $i (0 .. $m) { $id = pack('S2', $i, 0); $d{$id} = $i; $str{$id} = 'd'x$i; } + foreach $j (1 .. $n) { $id = pack('S2', 0, $j); $d{$id} = $j; $str{$id} = 'i'x$j; } + + foreach $i (1 .. $m) { + foreach $j (1 .. $n) { + if($sc[$i-1] eq $tc[$j-1]) { + $cost = 0; $type = 'e'; # equal + } else { + $cost = 1.1; $type = 's'; # substitution + } + + $aid = pack('S2', $i-1, $j); $a = $d{$aid} + 1; # deletion + $bid = pack('S2', $i, $j-1); $b = $d{$bid} + 1; # insertion + $cid = pack('S2', $i-1, $j-1); $c = $d{$cid} + $cost; # insertion + + $id = pack('S2', $i, $j); + + # we want matches to come at the end, so do deletions/insertions first + if($a <= $b and $a <= $c) { + $d{$id} = $a; + $type = 'd'; + $str{$id} = $str{$aid}.'d'; + } + elsif($b <= $c) { + $d{$id} = $b; + $type = 'i'; + $str{$id} = $str{$bid}.'i'; + } + else { + $d{$id} = $c; + $str{$id} = $str{$cid}.$type; + } + + delete $d{$cid}; + delete $str{$cid}; + # print "".$sc[$i-1]." ".$tc[$j-1]." $i $j $a $b $c $d[$id] $type\n" + } + } + + $id = pack('S2', $m, $n); + return ($str{$id}, $d{$id}); +} + +open REF, "<:utf8", $ARGV[0]; +open TEST, "<:utf8", $ARGV[1]; + +my ($reflen, $testlen); +my %scores = (); +my($ref, $test, $sent, $sentacc); +while($ref = and $test = ) { + chomp $ref; + chomp $test; + $ref =~ s/ //g; $ref =~ s/(.)/$1 /g; $ref =~ s/ $//g; + $test =~ s/ //g; $test =~ s/(.)/$1 /g; $test =~ s/ $//g; + + # get the arrays + my @ra = split(/ +/, $ref); + $reflen += @ra; + my @ta = split(/ +/, $test); + $testlen += @ta; + + # do levenshtein distance if the scores aren't equal + my ($hist, $score); + if ($ref eq $test) { + $sentacc++; + for (@ra) { $hist .= 'e'; } + $score = 0; + } else { + ($hist, $score) = levenshtein($ref, $test); + } + $sent++; + + my @ha = split(//, $hist); + my ($rd, $td, $hd, $h, $r, $t, $l); + if(not $PRINT_INLINE) { + while(@ha) { + $h = shift(@ha); + $scores{$h}++; + if($h eq 'e' or $h eq 's') { + $r = shift(@ra); + $t = shift(@ta); + } elsif ($h eq 'i') { + $r = ''; + $t = shift(@ta); + } elsif ($h eq 'd') { + $r = shift(@ra); + $t = ''; + } else { die "bad history value $h"; } + # find the length + $l = max(width($r), width($t)) + 1; + $rd .= pad($r, $l); + $td .= pad($t, $l); + $hd .= pad($h, $l); + } + print "$rd\n$td\n$hd\n\n"; + } else { + my (@er, @et, @dr, @dt); + while(@ha) { + $h = shift(@ha); + $scores{$h}++; + if($h eq 'e') { + if(@dr or @dt) { + print "X\t@dr\t@dt\n"; @dr = (); @dt = (); + } + push @er, shift(@ra); + push @et, shift(@ta); + } else { + if(@er or @et) { + die "@er != @et" if("@er" ne "@et"); + print "O\t@er\t@et\n"; @er = (); @et = (); + } + push @dr, shift(@ra) if $h ne 'i'; + push @dt, shift(@ta) if $h ne 'd'; + } + } + if(@dr or @dt) { print "X\t@dr\t@dt\n\n"; } + elsif(@er or @et) { print "O\t@er\t@et\n\n"; } + } + die "non-empty ra=@ra or ta=@ta\n" if(@ra or @ta); +} + +my $total = 0; +for (values %scores) { $total += $_; } +foreach my $k (keys %scores) { + print "$k: $scores{$k} (".$scores{$k}/$total*100 . "%)\n"; +} +my $wer = ($scores{'s'}+$scores{'i'}+$scores{'d'})/$reflen*100; +my $prec = $scores{'e'}/$testlen*100; +my $rec = $scores{'e'}/$reflen*100; +my $fmeas = (2*$prec*$rec)/($prec+$rec); +$sentacc = $sentacc/$sent*100; +printf ("WER: %.2f%%\nPrec: %.2f%%\nRec: %.2f%%\nF-meas: %.2f%%\nSent: %.2f%%\n", $wer, $prec, $rec, $fmeas, $sentacc); + diff --git a/script/gradepos.pl b/script/gradepos.pl new file mode 100755 index 0000000..dd757ec --- /dev/null +++ b/script/gradepos.pl @@ -0,0 +1,48 @@ +#!/usr/bin/perl + +use strict; +use utf8; +use Getopt::Long; +use List::Util qw(sum min max shuffle); +binmode STDIN, ":utf8"; +binmode STDOUT, ":utf8"; +binmode STDERR, ":utf8"; + +if(@ARGV != 2) { + print STDERR "Usage: $0 REFERENCE TEST\n"; + exit 1; +} + +open FILE0, "<:utf8", $ARGV[0] or die "Couldn't open $ARGV[0]\n"; +open FILE1, "<:utf8", $ARGV[1] or die "Couldn't open $ARGV[1]\n"; + +my (%mistakes, $total, $correct); +my ($s0, $s1); +while(($s0 = ) and ($s1 = )) { + chomp $s0; chomp $s1; + my @a0 = split(/ /, $s0); + my @a1 = split(/ /, $s1); + if(@a0 != @a1) { + print STDERR "Line lengths don't match:\n@a0\n@a1\n"; + exit 1; + } + foreach my $i (0 .. $#a0) { + $a0[$i] =~ s/\S*_//g; + $a1[$i] =~ s/\S*_//g; + $total++; + if ($a0[$i] eq $a1[$i]) { + $correct++; + } else { + $mistakes{"$a0[$i] --> $a1[$i]"}++; + } + } +} + +printf "Accuracy: %.02f%% (%d/%d)\n\nMost common mistakes:\n", $correct/$total*100, $correct, $total; + +my $mist_count; +for(sort { $mistakes{$b} <=> $mistakes{$a} } keys %mistakes) { + last if $mist_count++ >= 10; + print "$_\t$mistakes{$_}\n"; +} + diff --git a/script/gradews.pl b/script/gradews.pl new file mode 100755 index 0000000..46e3895 --- /dev/null +++ b/script/gradews.pl @@ -0,0 +1,74 @@ +#!/usr/bin/perl +use strict; + +sub lengths { + my @ret; + my @bounds = (@_,1); + my $last = -1; + for(0 .. $#bounds) { + if($bounds[$_]) { + push @ret, ($_-$last); + $last = $_; + } + } + return @ret; +} + +open REF, "<:utf8", $ARGV[0] or die $!; +open TEST, "<:utf8", $ARGV[1] or die $!; +my ($ref, $test); +my ($totb, $corb, $refw, $testw, $corw,$tots,$cors); +while($ref = and $test = ) { + chomp $ref; chomp $test; + $tots++; + $cors++ if ($ref eq $test); + $ref =~ s/\/[^ ]*//g; + $test =~ s/\/[^ ]*//g; + my @rarr = split(//, $ref); + my @tarr = split(//, $test); + shift @rarr; + shift @tarr; + my (@rb,@tb); + while(@rarr and @tarr) { + my $rs = ($rarr[0] eq ' '); + shift @rarr if $rs; + push @rb, $rs; + my $ts = ($tarr[0] eq ' '); + shift @tarr if $ts; + push @tb, $ts; + shift @tarr; + shift @rarr; + } + die "mismatched lines \n$ref\n$test\n" if(@rb != @tb); + # total boundaries + $totb += @rb; + # correct boundaries + for(0 .. $#rb) { $corb++ if ($rb[$_] == $tb[$_]); } + # find word counts + my @rlens = lengths(@rb); + $refw += @rlens; + my @tlens = lengths(@tb); + $testw += @tlens; + # print "$ref\n@rlens\n$test\n@tlens\n"; + # find word matches + my ($rlast, $tlast); + while(@rlens and @tlens) { + if($rlast == $tlast) { + $corw++ if($rlens[0] == $tlens[0]); + } + if($rlast <= $tlast) { + $rlast += shift(@rlens); + } + if($tlast < $rlast) { + $tlast += shift(@tlens); + } + } +} +print "Sent Accuracy: ".sprintf("%.2f", ($cors/$tots*100))."% ($cors/$tots)\n"; +my $precw = $corw/$testw; +my $recw = $corw/$refw; +my $fmeasw = (2*$precw*$recw)/($precw+$recw); +print "Word Prec: ".sprintf("%.2f", ($precw*100))."% ($corw/$testw)\n"; +print "Word Rec: ".sprintf("%.2f", ($recw*100))."% ($corw/$refw)\n"; +print "F-meas: ".sprintf("%.2f", ($fmeasw*100))."%\n"; +print "Bound Accuracy: ".sprintf("%.2f", ($corb/$totb*100))."% ($corb/$totb)\n"; diff --git a/script/print-trees.py b/script/print-trees.py new file mode 100755 index 0000000..e4e8432 --- /dev/null +++ b/script/print-trees.py @@ -0,0 +1,12 @@ +#!/usr/bin/python + +from nltk.tree import Tree +import sys + +# A program to display parse trees (in Penn treebank format) with NLTK +# +# To install NLTK on ubuntu: sudo apt-get install python-nltk + +for line in sys.stdin: + t = Tree.parse(line) + t.draw() \ No newline at end of file