Skip to content


Added grading scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
neubig committed Jul 7, 2014
1 parent 25ce077 commit 309d9ca
Show file tree
Hide file tree
Showing 6 changed files with 366 additions and 0 deletions.
23 changes: 23 additions & 0 deletions script/
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

import sys

def count_correct_heads(refs, tsts):
    """Count reference tokens whose head column matches the test output.

    refs and tsts are sequences of lines.  Every non-blank reference line,
    and the test line at the same index, must hold exactly eight
    tab-separated columns; the seventh column (the "head" field) is compared.
    Lines that are blank in the reference are skipped.

    Returns a (correct, total) tuple.
    Raises ValueError if the test data has no line for a reference token or
    if a compared line does not have eight columns.
    """
    total = 0
    correct = 0
    for i, ref_line in enumerate(refs):
        ref_line = ref_line.strip()
        if len(ref_line) == 0:
            continue
        if i >= len(tsts):
            raise ValueError("test file has no line %d" % (i + 1))
        ref_cols = ref_line.split('\t')
        tst_cols = tsts[i].strip().split('\t')
        if len(ref_cols) != 8 or len(tst_cols) != 8:
            raise ValueError("line %d: expected 8 tab-separated columns" % (i + 1))
        total += 1
        if ref_cols[6] == tst_cols[6]:
            correct += 1
    return correct, total


def main(argv):
    """Grade argv[2] (test) against argv[1] (reference) and print the accuracy."""
    if len(argv) != 3:
        sys.stderr.write("Usage: %s REFERENCE TEST\n" % argv[0])
        sys.exit(1)
    # "with" closes both files even if reading fails.
    with open(argv[1], "r") as ref_file:
        refs = ref_file.readlines()
    with open(argv[2], "r") as tst_file:
        tsts = tst_file.readlines()

    correct, total = count_correct_heads(refs, tsts)
    # An empty reference would otherwise raise ZeroDivisionError.
    percent = float(correct) / total * 100 if total else 0.0
    # A single parenthesised argument prints identically on Python 2 and 3.
    print("%f%% (%d/%d)" % (percent, correct, total))


if __name__ == "__main__":
    main(sys.argv)
37 changes: 37 additions & 0 deletions script/
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

import sys
import math

def load_labels(path):
    """Return the first tab-separated column of every line in the file at path.

    NOTE(review): the original loop split each line into columns but never
    stored anything, so the compared lists were always empty.  Using the
    first column as the label is an assumption -- confirm against the data
    format.  Labels are compared as strings.
    """
    labels = list()
    with open(path, "r") as label_file:
        for line in label_file:
            columns = line.strip().split("\t")
            labels.append(columns[0])
    return labels


def count_matches(ref, test):
    """Return (correct, total) for two equally long label sequences.

    correct is the number of positions where ref[i] == test[i]; total is
    len(ref).
    """
    correct = sum(1 for r, t in zip(ref, test) if r == t)
    return correct, len(ref)


def main(argv):
    """Grade argv[2] (test) against argv[1] (reference) and print the accuracy."""
    if len(argv) != 3:
        sys.stderr.write("Usage: %s REFERENCE TEST\n" % argv[0])
        sys.exit(1)

    # Load the reference file
    ref = load_labels(argv[1])

    # Load the testing file
    test = load_labels(argv[2])

    # Check to make sure that both are the same length; grading misaligned
    # files would give a meaningless number, so stop here.
    if len(test) != len(ref):
        sys.stderr.write("Lengths of test (%i) and reference (%i) file don't match\n"
                         % (len(test), len(ref)))
        sys.exit(1)

    correct, total = count_matches(ref, test)
    # Two empty files would otherwise raise ZeroDivisionError.
    accuracy = float(correct) / float(total) * 100.0 if total else 0.0
    print("Accuracy = %f%%" % accuracy)


if __name__ == "__main__":
    main(sys.argv)
172 changes: 172 additions & 0 deletions script/
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@

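# This is a script to grade error rates according to edit distance.
# Each line has all of its spaces removed and is then compared character by
# character, so the reported "WER" is a character-level error rate.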

# Put the :utf8 layer on the three standard streams so that text read from
# STDIN and written to STDOUT/STDERR is handled as characters, not raw bytes.
binmode STDIN, ":utf8";
binmode STDOUT, ":utf8";
binmode STDERR, ":utf8";

use utf8;
use strict;
use List::Util qw(max min);
use Cwd qw(cwd);


# Return the display width of a string: characters in the Katakana, Hiragana,
# CJK Symbols and Punctuation, Katakana Phonetic Extensions and CJK Unified
# Ideographs blocks count as 2, every other character counts as 1.
# An undefined or empty argument has width 0.
sub width {
    # Copy the argument into a lexical; assigning it to the global $_ would
    # overwrite whatever the caller currently holds there.
    my ($text) = @_;
    return 0 unless defined $text;
    my $ret = 0;
    for my $char (split(//, $text)) {
        $ret += ($char =~ /[\p{InKatakana}\p{InHiragana}\p{InCJKSymbolsAndPunctuation}\p{InKatakanaPhoneticExtensions}\p{InCJKUnifiedIdeographs}]/) ? 2 : 1;
    }
    return $ret;
}

# Right-pad $s with spaces until its display width, as measured by width(),
# reaches $l.  A string that is already at least $l wide is returned unchanged.
sub pad {
    my ($s, $l) = @_;
    my $missing = $l - width($s);
    # Make the "nothing to add" case explicit instead of relying on a
    # negative repeat count producing an empty string.
    return $missing > 0 ? $s . (' ' x $missing) : $s;
}

# return a minimum edit distance path with the following notation, plus cost
#  d=delete i=insert s=substitute e=equal
#
# Arguments: $s (source) and $t (target), each a string of tokens separated
# by one or more spaces.
# Returns:   ($path, $cost) where $path holds one letter per edit operation
#            and $cost is the accumulated cost: 1 per insertion or deletion,
#            1.1 per substitution and 0 per match.
#
# Only the previous and the current row of the table are kept, so memory
# grows with the target length rather than with the product of both lengths.
sub levenshtein {
    my ($s, $t) = @_;
    my @sc = split(/ +/, $s);
    my @tc = split(/ +/, $t);
    my $m  = @sc;
    my $n  = @tc;

    # initialize: row 0 turns an empty source into the first $j target tokens
    # with $j insertions
    my @prev_cost = (0 .. $n);
    my @prev_path = map { 'i' x $_ } (0 .. $n);

    foreach my $i (1 .. $m) {
        # column 0 turns the first $i source tokens into nothing: $i deletions
        my @cur_cost = ($i);
        my @cur_path = ('d' x $i);

        foreach my $j (1 .. $n) {
            my ($cost, $type);
            if($sc[$i-1] eq $tc[$j-1]) {
                $cost = 0; $type = 'e'; # equal
            } else {
                $cost = 1.1; $type = 's'; # substitution
            }

            # Lexical names are used on purpose: $a and $b are the package
            # globals that sort() relies on.
            my $del  = $prev_cost[$j]   + 1;     # deletion
            my $ins  = $cur_cost[$j-1]  + 1;     # insertion
            my $diag = $prev_cost[$j-1] + $cost; # match or substitution

            # we want matches to come at the end, so do deletions/insertions first
            if($del <= $ins and $del <= $diag) {
                $cur_cost[$j] = $del;
                $cur_path[$j] = $prev_path[$j].'d';
            } elsif($ins <= $diag) {
                $cur_cost[$j] = $ins;
                $cur_path[$j] = $cur_path[$j-1].'i';
            } else {
                $cur_cost[$j] = $diag;
                $cur_path[$j] = $prev_path[$j-1].$type;
            }
        }

        @prev_cost = @cur_cost;
        @prev_path = @cur_path;
    }

    return ($prev_path[$n], $prev_cost[$n]);
}

# Main program: align every reference line with the test line at the same
# position, print the alignment, and finish with corpus-level statistics.

# When true, print tab-separated "O" (matching) / "X" (differing) chunks;
# when false, print a three-line aligned display per sentence.
# NOTE(review): this flag is read below but was never declared in the visible
# part of the file; 0 reproduces the behaviour of an unset value -- confirm
# the intended default.
my $PRINT_INLINE = 0;

die "Usage: $0 REFERENCE TEST\n" if @ARGV != 2;
open(my $ref_fh,  "<:utf8", $ARGV[0]) or die "Couldn't open $ARGV[0]: $!\n";
open(my $test_fh, "<:utf8", $ARGV[1]) or die "Couldn't open $ARGV[1]: $!\n";

my ($reflen, $testlen) = (0, 0);
my %scores = ();
my ($sent, $sentacc) = (0, 0);
while(1) {
    my $ref  = <$ref_fh>;
    my $test = <$test_fh>;
    # Test definedness, not truth: a final line "0" without a newline is false.
    if(not defined $ref or not defined $test) {
        print STDERR "Warning: reference and test have different numbers of lines\n"
            if(defined $ref or defined $test);
        last;
    }
    chomp $ref;
    chomp $test;
    # Remove every space, then put one space between all characters, so that
    # the comparison below is character by character.
    for my $line ($ref, $test) {
        $line =~ s/ //g; $line =~ s/(.)/$1 /g; $line =~ s/ $//g;
    }

    # Sentence-level accuracy.  NOTE(review): the counters were divided at
    # the end but never updated in the visible code; counting exact matches
    # of the space-normalized lines is the presumed intent -- confirm.
    $sent++;
    $sentacc++ if($ref eq $test);

    # get the arrays
    my @ra = split(/ +/, $ref);
    $reflen += @ra;
    my @ta = split(/ +/, $test);
    $testlen += @ta;

    # do levenshtein distance if the lines aren't equal
    my $hist;
    if ($ref eq $test) {
        $hist = 'e' x scalar(@ra);
    } else {
        ($hist) = levenshtein($ref, $test);
    }

    my @ha = split(//, $hist);
    # Tally every edit operation; the summary below reads these counts.
    $scores{$_}++ for @ha;

    if(not $PRINT_INLINE) {
        my ($rd, $td, $hd) = ('', '', '');
        for my $h (@ha) {
            my ($r, $t);
            if($h eq 'e' or $h eq 's') {
                $r = shift(@ra);
                $t = shift(@ta);
            } elsif ($h eq 'i') {
                $r = '';
                $t = shift(@ta);
            } elsif ($h eq 'd') {
                $r = shift(@ra);
                $t = '';
            } else { die "bad history value $h"; }
            # find the length: the wider of the two cells plus one separator
            my $cell = max(width($r), width($t)) + 1;
            $rd .= pad($r, $cell);
            $td .= pad($t, $cell);
            $hd .= pad($h, $cell);
        }
        print "$rd\n$td\n$hd\n\n";
    } else {
        # @er/@et collect the current run of equal tokens, @dr/@dt the current
        # run of differing ones; a run is printed when the other kind starts.
        my (@er, @et, @dr, @dt);
        for my $h (@ha) {
            if($h eq 'e') {
                if(@dr or @dt) {
                    print "X\t@dr\t@dt\n"; @dr = (); @dt = ();
                }
                push @er, shift(@ra);
                push @et, shift(@ta);
            } else {
                if(@er or @et) {
                    die "@er != @et" if("@er" ne "@et");
                    print "O\t@er\t@et\n"; @er = (); @et = ();
                }
                push @dr, shift(@ra) if $h ne 'i';
                push @dt, shift(@ta) if $h ne 'd';
            }
        }
        if(@dr or @dt) { print "X\t@dr\t@dt\n\n"; }
        elsif(@er or @et) { print "O\t@er\t@et\n\n"; }
    }
    # The edit path must have consumed both token lists completely.
    die "non-empty ra=@ra or ta=@ta\n" if(@ra or @ta);
}
close($ref_fh);
close($test_fh);

my $total = 0;
for (values %scores) { $total += $_; }
# Sorted so that the output order is the same on every run.
foreach my $k (sort keys %scores) {
    print "$k: $scores{$k} (".$scores{$k}/$total*100 . "%)\n";
}
# Operations that never occurred have no hash entry; treat them as 0, and
# guard every division so that empty input reports 0 instead of dying.
my $errors  = ($scores{'s'} || 0) + ($scores{'i'} || 0) + ($scores{'d'} || 0);
my $matches = $scores{'e'} || 0;
my $wer   = $reflen  ? $errors/$reflen*100   : 0;
my $prec  = $testlen ? $matches/$testlen*100 : 0;
my $rec   = $reflen  ? $matches/$reflen*100  : 0;
my $fmeas = ($prec+$rec) ? (2*$prec*$rec)/($prec+$rec) : 0;
$sentacc  = $sent ? $sentacc/$sent*100 : 0;
printf ("WER: %.2f%%\nPrec: %.2f%%\nRec: %.2f%%\nF-meas: %.2f%%\nSent: %.2f%%\n", $wer, $prec, $rec, $fmeas, $sentacc);

48 changes: 48 additions & 0 deletions script/
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

use strict;
use utf8;
use Getopt::Long;
use List::Util qw(sum min max shuffle);
# Put the :utf8 layer on the standard streams so text is handled as characters.
binmode STDIN, ":utf8";
binmode STDOUT, ":utf8";
binmode STDERR, ":utf8";

# Main program: compare the tokens of REFERENCE and TEST position by
# position, then print the accuracy and the ten most frequent confusions.
if(@ARGV != 2) {
    print STDERR "Usage: $0 REFERENCE TEST\n";
    exit 1;
}

open(my $ref_fh,  "<:utf8", $ARGV[0]) or die "Couldn't open $ARGV[0]\n";
open(my $test_fh, "<:utf8", $ARGV[1]) or die "Couldn't open $ARGV[1]\n";

my %mistakes;
my ($total, $correct) = (0, 0);
while(1) {
    my $s0 = <$ref_fh>;
    my $s1 = <$test_fh>;
    # Test definedness, not truth: a final line "0" without a newline is false.
    last unless(defined $s0 and defined $s1);
    chomp $s0; chomp $s1;
    my @a0 = split(/ /, $s0);
    my @a1 = split(/ /, $s1);
    if(@a0 != @a1) {
        print STDERR "Line lengths don't match:\n@a0\n@a1\n";
        exit 1;
    }
    foreach my $i (0 .. $#a0) {
        # Drop everything up to and including the last underscore, keeping
        # only what follows it (presumably the tag of a word_TAG token --
        # confirm against the data format).
        $a0[$i] =~ s/\S*_//g;
        $a1[$i] =~ s/\S*_//g;
        # These two counters were never updated in the original, which made
        # the accuracy line below divide by zero.
        $total++;
        if ($a0[$i] eq $a1[$i]) {
            $correct++;
        } else {
            $mistakes{"$a0[$i] --> $a1[$i]"}++;
        }
    }
}
close($ref_fh);
close($test_fh);

# Empty input reports 0% instead of dying on a division by zero.
my $accuracy = $total ? $correct/$total*100 : 0;
printf "Accuracy: %.02f%% (%d/%d)\n\nMost common mistakes:\n", $accuracy, $correct, $total;

# Most frequent first; ties are broken by name so the output is reproducible.
my $mist_count = 0;
for my $mistake (sort { $mistakes{$b} <=> $mistakes{$a} or $a cmp $b } keys %mistakes) {
    last if $mist_count++ >= 10;
    print "$mistake\t$mistakes{$mistake}\n";
}

74 changes: 74 additions & 0 deletions script/
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use strict;

# Convert a list of boundary flags into segment lengths.
# Each argument is a flag for one position; a true flag closes the segment
# that ends at that position.  One extra true flag is appended, so the last
# segment is always closed and an empty argument list yields (1).
# Example: lengths(0, 1, 0) returns (2, 2).
sub lengths {
    my @bounds = (@_, 1);
    my @ret;
    my $last = -1;    # index of the most recent true flag
    for my $pos (0 .. $#bounds) {
        next unless $bounds[$pos];
        push @ret, ($pos - $last);
        $last = $pos;
    }
    return @ret;
}

# Main program: compare the segmentation of every REFERENCE line with the
# TEST line at the same position and print sentence, word and boundary scores.
die "Usage: $0 REFERENCE TEST\n" if @ARGV != 2;
open(my $ref_fh,  "<:utf8", $ARGV[0]) or die $!;
open(my $test_fh, "<:utf8", $ARGV[1]) or die $!;

# totb/corb: total and correct boundary decisions; refw/testw/corw: reference,
# test and matching word counts; tots/cors: total and identical lines.
my ($totb, $corb, $refw, $testw, $corw, $tots, $cors) = (0, 0, 0, 0, 0, 0, 0);
while(1) {
    my $ref  = <$ref_fh>;
    my $test = <$test_fh>;
    # Test definedness, not truth: a final line "0" without a newline is false.
    last unless(defined $ref and defined $test);
    chomp $ref; chomp $test;
    # The line counter was never updated in the original, which made the
    # sentence accuracy below divide by zero.
    $tots++;
    $cors++ if ($ref eq $test);
    # Remove each "/" together with the non-space characters that follow it
    $ref =~ s/\/[^ ]*//g;
    $test =~ s/\/[^ ]*//g;

    # Both lines must hold the same characters once spaces are ignored.  The
    # original compared the sizes of @rb and @tb, but those grow in lockstep
    # and can never differ, so a mismatch was never detected.
    (my $rchars = $ref)  =~ s/ //g;
    (my $tchars = $test) =~ s/ //g;
    die "mismatched lines \n$ref\n$test\n" if($rchars ne $tchars);

    my @rarr = split(//, $ref);
    my @tarr = split(//, $test);
    # Skip the first character: a boundary can only precede a later one
    shift @rarr;
    shift @tarr;
    # One flag per remaining character: 1 if a space comes right before it
    my (@rb, @tb);
    while(@rarr and @tarr) {
        my $rs = ($rarr[0] eq ' ') ? 1 : 0;
        shift @rarr if $rs;
        push @rb, $rs;
        my $ts = ($tarr[0] eq ' ') ? 1 : 0;
        shift @tarr if $ts;
        push @tb, $ts;
        shift @tarr;
        shift @rarr;
    }
    # total boundaries
    $totb += @rb;
    # correct boundaries
    for my $i (0 .. $#rb) { $corb++ if ($rb[$i] == $tb[$i]); }
    # find word counts
    my @rlens = lengths(@rb);
    $refw += @rlens;
    my @tlens = lengths(@tb);
    $testw += @tlens;
    # find word matches: walk both length lists in parallel; a word is
    # correct when both sides start at the same offset and have equal length
    my ($rlast, $tlast) = (0, 0);
    while(@rlens and @tlens) {
        if($rlast == $tlast) {
            $corw++ if($rlens[0] == $tlens[0]);
        }
        if($rlast <= $tlast) {
            $rlast += shift(@rlens);
        }
        if($tlast < $rlast) {
            $tlast += shift(@tlens);
        }
    }
}
close($ref_fh);
close($test_fh);

# Divide, reporting 0 instead of dying when the denominator is 0 (empty input).
my $ratio = sub {
    my ($num, $den) = @_;
    return $den ? $num/$den : 0;
};

print "Sent Accuracy: ".sprintf("%.2f", ($ratio->($cors, $tots)*100))."% ($cors/$tots)\n";
my $precw = $ratio->($corw, $testw);
my $recw = $ratio->($corw, $refw);
my $fmeasw = $ratio->(2*$precw*$recw, $precw+$recw);
print "Word Prec: ".sprintf("%.2f", ($precw*100))."% ($corw/$testw)\n";
print "Word Rec: ".sprintf("%.2f", ($recw*100))."% ($corw/$refw)\n";
print "F-meas: ".sprintf("%.2f", ($fmeasw*100))."%\n";
print "Bound Accuracy: ".sprintf("%.2f", ($ratio->($corb, $totb)*100))."% ($corb/$totb)\n";
12 changes: 12 additions & 0 deletions script/
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

from nltk.tree import Tree
import sys

# A program to display parse trees (in Penn treebank format) with NLTK
# To install NLTK on ubuntu: sudo apt-get install python-nltk

# Tree.parse() was renamed Tree.fromstring() in NLTK 3; use whichever this
# NLTK version provides.
parse_tree = getattr(Tree, "fromstring", None) or Tree.parse

for line in sys.stdin:
    line = line.strip()
    # Blank lines carry no tree; skip them instead of failing to parse.
    if not line:
        continue
    t = parse_tree(line)
    # The original parsed each tree but never showed it, although the header
    # comment says the program displays them.
    # NOTE(review): draw() opens a window and needs a graphical display --
    # confirm this is the intended way to show the tree.
    t.draw()

0 comments on commit 309d9ca

Please sign in to comment.