Skip to content

Commit

Permalink
Measure functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Amartya Sanyal committed Jul 13, 2017
1 parent ae3d409 commit f871177
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 74 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ SET(luasrc
TotalDropout.lua
VRClassReward.lua
ReverseUnreverse.lua
measure.lua
deprecated/SeqLSTMP.lua
deprecated/SeqReverseSequence.lua
deprecated/BiSequencerLM.lua
Expand Down
4 changes: 3 additions & 1 deletion init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ paths.require 'librnn'
unpack = unpack or table.unpack

require('rnn.utils')

-- extensions to existing nn.Module
require('rnn.Module')
require('rnn.Container')
Expand Down Expand Up @@ -112,6 +111,9 @@ require('rnn.SeqLSTMP')
require('rnn.SeqReverseSequence')
require('rnn.BiSequencerLM')


require('rnn.measure')

-- prevent likely name conflicts
nn.rnn = rnn

Expand Down
27 changes: 27 additions & 0 deletions measure.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
function nn.get_bleu(cand, ref, n)
n = n or 4
local smooth = 1
if type(cand) ~= 'table' then
cand = cand:totable()
end
if type(ref) ~= 'table' then
ref = ref:totable()
end
local res = nn.utils.get_ngram_prec(cand, ref, n)
local brevPen = math.exp(1-math.max(1, #ref/#cand))
local correct = 0
local total = 0
local bleu = 1
for i = 1, n do
if res[i][1] > 0 then
if res[i][2] == 0 then
smooth = smooth*0.5
res[i][2] = smooth
end
local prec = res[i][2]/res[i][1]
bleu = bleu * prec
end
end
bleu = bleu^(1/n)
return bleu*brevPen
end
73 changes: 1 addition & 72 deletions scripts/evaluate-rnnlm.lua
Original file line number Diff line number Diff line change
Expand Up @@ -42,77 +42,6 @@ local validerr = xplog.valnceloss or xplog.valppl

print(string.format("Error (epoch=%d): training=%f; validation=%f", xplog.epoch, trainerr[#trainerr], validerr[#validerr]))


local function get_ngrams(sent, n, count)
local ngrams = {}
for beg = 1, #sent do
for last= beg, math.min(beg+n-1, #sent) do
local ngram = table.concat(sent, ' ', beg, last)
local len = last-beg+1 -- keep track of ngram length
if not count then
table.insert(ngrams, ngram)
else
if ngrams[ngram] == nil then
ngrams[ngram] = {1, len}
else
ngrams[ngram][1] = ngrams[ngram][1] + 1
end
end
end
end
return ngrams
end

local function get_ngram_prec(cand, ref, n)
local results = {}
for i = 1, n do
results[i] = {0, 0}
end
local cand_ngrams = get_ngrams(cand, n, 1)
local ref_ngrams = get_ngrams(ref, n, 1)
for ngram, dist in pairs(cand_ngrams) do
local freq = dist[1]
local length = dist[2]
results[length][1] = results[length][1] + freq
local actual
if ref_ngrams[ngram] == nil then
actual = 0
else
actual = ref_ngrams[ngram][1]
end
results[length][2] = results[length][2] + math.min(actual, freq)
end
return results
end

function get_bleu(cand, ref, n)
n = n or 4
local smooth = 1
if type(cand) ~= 'table' then
cand = cand:totable()
end
if type(ref) ~= 'table' then
ref = ref:totable()
end
local res = get_ngram_prec(cand, ref, n)
local brevPen = math.exp(1-math.max(1, #ref/#cand))
local correct = 0
local total = 0
local bleu = 1
for i = 1, n do
if res[i][1] > 0 then
if res[i][2] == 0 then
smooth = smooth*0.5
res[i][2] = smooth
end
local prec = res[i][2]/res[i][1]
bleu = bleu * prec
end
end
bleu = bleu^(1/n)
return bleu*brevPen
end

if opt.dumpcsv then
local csvfile = opt.xplogpath:match('([^/]+)[.]t7$')..'.csv'
paths.mkdir('learningcurves')
Expand Down Expand Up @@ -220,7 +149,7 @@ else
if opt.bleu then
max_ind = torch.multinomial(torch.exp(outputs:view(targets:nElement(), -1)), 1):view(targets:size(1),targets:size(2))
for batchIdx=1, targets:size(2) do
sum_bleu = sum_bleu + get_bleu(max_ind:select(2, batchIdx),
sum_bleu = sum_bleu + nn.get_bleu(max_ind:select(2, batchIdx),
targets:select(2, batchIdx),
opt.blueN)
num_sent = num_sent + 1
Expand Down
43 changes: 42 additions & 1 deletion utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -290,4 +290,45 @@ function nn.utils.setZeroMask(modules, zeroMask, cuda)
for i,module in ipairs(torch.type(modules) == 'table' and modules or {modules}) do
module:setZeroMask(zeroMask)
end
end
end
function nn.utils.get_ngrams(sent, n, count)
local ngrams = {}
for beg = 1, #sent do
for last= beg, math.min(beg+n-1, #sent) do
local ngram = table.concat(sent, ' ', beg, last)
local len = last-beg+1 -- keep track of ngram length
if not count then
table.insert(ngrams, ngram)
else
if ngrams[ngram] == nil then
ngrams[ngram] = {1, len}
else
ngrams[ngram][1] = ngrams[ngram][1] + 1
end
end
end
end
return ngrams
end

function nn.utils.get_ngram_prec(cand, ref, n)
local results = {}
for i = 1, n do
results[i] = {0, 0}
end
local cand_ngrams = nn.utils.get_ngrams(cand, n, 1)
local ref_ngrams = nn.utils.get_ngrams(ref, n, 1)
for ngram, dist in pairs(cand_ngrams) do
local freq = dist[1]
local length = dist[2]
results[length][1] = results[length][1] + freq
local actual
if ref_ngrams[ngram] == nil then
actual = 0
else
actual = ref_ngrams[ngram][1]
end
results[length][2] = results[length][2] + math.min(actual, freq)
end
return results
end

0 comments on commit f871177

Please sign in to comment.