diff --git a/CMakeLists.txt b/CMakeLists.txt
index dc0fc41..44a6ff9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -84,6 +84,7 @@ SET(luasrc
    TotalDropout.lua
    VRClassReward.lua
    ReverseUnreverse.lua
+   measure.lua
    deprecated/SeqLSTMP.lua
    deprecated/SeqReverseSequence.lua
    deprecated/BiSequencerLM.lua
diff --git a/init.lua b/init.lua
index 2c14b85..eb127e2 100644
--- a/init.lua
+++ b/init.lua
@@ -27,7 +27,6 @@ paths.require 'librnn'
 unpack = unpack or table.unpack
 require('rnn.utils')
 
-
 -- extensions to existing nn.Module
 require('rnn.Module')
 require('rnn.Container')
@@ -112,6 +111,9 @@ require('rnn.SeqLSTMP')
 require('rnn.SeqReverseSequence')
 require('rnn.BiSequencerLM')
 
+
+require('rnn.measure')
+
 -- prevent likely name conflicts
 nn.rnn = rnn
 
diff --git a/measure.lua b/measure.lua
new file mode 100644
index 0000000..ad32970
--- /dev/null
+++ b/measure.lua
@@ -0,0 +1,37 @@
+--- Sentence-level BLEU of a candidate against a single reference.
+-- cand, ref : token sequences; anything that is not a Lua table is
+--             converted with :totable() first.
+-- n         : highest n-gram order to score (defaults to 4).
+-- Returns brevity penalty * (product of n-gram precisions)^(1/n).
+-- Orders for which the candidate has no n-grams are left out of the
+-- product; an order with n-grams but no match gets a pseudo-count that
+-- is halved on every such order (0.5, 0.25, ...).
+function nn.get_bleu(cand, ref, n)
+   n = n or 4
+   if type(cand) ~= 'table' then
+      cand = cand:totable()
+   end
+   if type(ref) ~= 'table' then
+      ref = ref:totable()
+   end
+   -- nothing to score; also keeps #ref/#cand below from dividing by zero
+   if #cand == 0 then
+      return 0
+   end
+   local res = nn.utils.get_ngram_prec(cand, ref, n)
+   local brevPen = math.exp(1-math.max(1, #ref/#cand))
+   local smooth = 1
+   local bleu = 1
+   for i = 1, n do
+      local total, matched = res[i][1], res[i][2]
+      if total > 0 then
+         if matched == 0 then
+            smooth = smooth*0.5
+            matched = smooth
+         end
+         bleu = bleu * (matched/total)
+      end
+   end
+   bleu = bleu^(1/n)
+   return bleu*brevPen
+end
diff --git a/scripts/evaluate-rnnlm.lua b/scripts/evaluate-rnnlm.lua
index 92d2306..4d823d6 100644
--- a/scripts/evaluate-rnnlm.lua
+++ b/scripts/evaluate-rnnlm.lua
@@ -42,77 +42,6 @@ local validerr = xplog.valnceloss or xplog.valppl
 
 print(string.format("Error (epoch=%d): training=%f; validation=%f", xplog.epoch, trainerr[#trainerr], validerr[#validerr]))
 
-
-local function get_ngrams(sent, n, count)
-   local ngrams = {}
-   for beg = 1, #sent do
-      for last= beg, math.min(beg+n-1, #sent) do
-         local ngram = table.concat(sent, ' ', beg, last)
-         local len = last-beg+1 -- keep track of ngram length
-         if not count then
-            table.insert(ngrams, ngram)
-         else
-            if ngrams[ngram] == nil then
-               ngrams[ngram] = {1, len}
-            else
-               ngrams[ngram][1] = ngrams[ngram][1] + 1
-            end
-         end
-      end
-   end
-   return ngrams
-end
-
-local function get_ngram_prec(cand, ref, n)
-   local results = {}
-   for i = 1, n do
-      results[i] = {0, 0}
-   end
-   local cand_ngrams = get_ngrams(cand, n, 1)
-   local ref_ngrams = get_ngrams(ref, n, 1)
-   for ngram, dist in pairs(cand_ngrams) do
-      local freq = dist[1]
-      local length = dist[2]
-      results[length][1] = results[length][1] + freq
-      local actual
-      if ref_ngrams[ngram] == nil then
-         actual = 0
-      else
-         actual = ref_ngrams[ngram][1]
-      end
-      results[length][2] = results[length][2] + math.min(actual, freq)
-   end
-   return results
-end
-
-function get_bleu(cand, ref, n)
-   n = n or 4
-   local smooth = 1
-   if type(cand) ~= 'table' then
-      cand = cand:totable()
-   end
-   if type(ref) ~= 'table' then
-      ref = ref:totable()
-   end
-   local res = get_ngram_prec(cand, ref, n)
-   local brevPen = math.exp(1-math.max(1, #ref/#cand))
-   local correct = 0
-   local total = 0
-   local bleu = 1
-   for i = 1, n do
-      if res[i][1] > 0 then
-         if res[i][2] == 0 then
-            smooth = smooth*0.5
-            res[i][2] = smooth
-         end
-         local prec = res[i][2]/res[i][1]
-         bleu = bleu * prec
-      end
-   end
-   bleu = bleu^(1/n)
-   return bleu*brevPen
-end
-
 if opt.dumpcsv then
    local csvfile = opt.xplogpath:match('([^/]+)[.]t7$')..'.csv'
    paths.mkdir('learningcurves')
@@ -220,7 +149,7 @@ else
       if opt.bleu then
          max_ind = torch.multinomial(torch.exp(outputs:view(targets:nElement(), -1)), 1):view(targets:size(1),targets:size(2))
          for batchIdx=1, targets:size(2) do
-            sum_bleu = sum_bleu + get_bleu(max_ind:select(2, batchIdx),
+            sum_bleu = sum_bleu + nn.get_bleu(max_ind:select(2, batchIdx),
                                            targets:select(2, batchIdx),
                                            opt.blueN)
             num_sent = num_sent + 1
diff --git a/utils.lua b/utils.lua
index 4c8eac1..fd461a4 100644
--- a/utils.lua
+++ b/utils.lua
@@ -290,4 +290,46 @@ function nn.utils.setZeroMask(modules, zeroMask, cuda)
    for i,module in ipairs(torch.type(modules) == 'table' and modules or {modules}) do
       module:setZeroMask(zeroMask)
    end
-end
\ No newline at end of file
+end
+
+-- Enumerates the n-grams of orders 1..n in the token table `sent`; each
+-- n-gram is keyed by its tokens joined with single spaces.
+-- When `count` is nil/false the result is a plain list of n-gram strings
+-- (repeats included); otherwise it maps n-gram -> {occurrences, order}.
+function nn.utils.get_ngrams(sent, n, count)
+   local ngrams = {}
+   for beg = 1, #sent do
+      for last = beg, math.min(beg+n-1, #sent) do
+         local ngram = table.concat(sent, ' ', beg, last)
+         if not count then
+            table.insert(ngrams, ngram)
+         elseif ngrams[ngram] == nil then
+            ngrams[ngram] = {1, last-beg+1} -- {occurrences, ngram length}
+         else
+            ngrams[ngram][1] = ngrams[ngram][1] + 1
+         end
+      end
+   end
+   return ngrams
+end
+
+-- Per-order n-gram statistics of `cand` measured against `ref`.
+-- Returns results[k] = {total, matched} for k = 1..n, where total is the
+-- number of k-grams in cand and matched sums, over distinct k-grams,
+-- min(count in cand, count in ref).
+function nn.utils.get_ngram_prec(cand, ref, n)
+   local results = {}
+   for i = 1, n do
+      results[i] = {0, 0}
+   end
+   local cand_ngrams = nn.utils.get_ngrams(cand, n, true)
+   local ref_ngrams = nn.utils.get_ngrams(ref, n, true)
+   for ngram, dist in pairs(cand_ngrams) do
+      local freq, length = dist[1], dist[2]
+      local ref_dist = ref_ngrams[ngram]
+      local actual = ref_dist and ref_dist[1] or 0
+      results[length][1] = results[length][1] + freq
+      results[length][2] = results[length][2] + math.min(actual, freq)
+   end
+   return results
+end