Skip to content

Commit def0758

Browse files
Merge pull request #34 from torch/measure
Measure score functions added
2 parents ae3d409 + e3bfcb0 commit def0758

File tree

6 files changed

+218
-74
lines changed

6 files changed

+218
-74
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ SET(luasrc
8484
TotalDropout.lua
8585
VRClassReward.lua
8686
ReverseUnreverse.lua
87+
measure.lua
8788
deprecated/SeqLSTMP.lua
8889
deprecated/SeqReverseSequence.lua
8990
deprecated/BiSequencerLM.lua

init.lua

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ paths.require 'librnn'
2727
unpack = unpack or table.unpack
2828

2929
require('rnn.utils')
30-
3130
-- extensions to existing nn.Module
3231
require('rnn.Module')
3332
require('rnn.Container')
@@ -112,6 +111,9 @@ require('rnn.SeqLSTMP')
112111
require('rnn.SeqReverseSequence')
113112
require('rnn.BiSequencerLM')
114113

114+
115+
require('rnn.measure')
116+
115117
-- prevent likely name conflicts
116118
nn.rnn = rnn
117119

measure.lua

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
function nn.get_bleu(cand, ref, n)
2+
n = n or 4
3+
local smooth = 1
4+
if type(cand) ~= 'table' then
5+
cand = cand:totable()
6+
end
7+
if type(ref) ~= 'table' then
8+
ref = ref:totable()
9+
end
10+
local res = nn.utils.get_ngram_prec(cand, ref, n)
11+
local brevPen = math.exp(1-math.max(1, #ref/#cand))
12+
local correct = 0
13+
local total = 0
14+
local bleu = 1
15+
for i = 1, n do
16+
if res[i][1] > 0 then
17+
if res[i][2] == 0 then
18+
smooth = smooth*0.5
19+
res[i][2] = smooth
20+
end
21+
local prec = res[i][2]/res[i][1]
22+
bleu = bleu * prec
23+
end
24+
end
25+
bleu = bleu^(1/n)
26+
return bleu*brevPen
27+
end
28+
29+
function nn.get_rougeN(cand, ref, n, weight)
30+
n = n or 4
31+
weight = weight or {}
32+
if #weight == 0 then
33+
for i=1, n do
34+
weight[i] = 0
35+
end
36+
weight[n] = 1
37+
end
38+
if type(cand) ~= 'table' then
39+
cand = cand:totable()
40+
end
41+
if type(ref) ~= 'table' then
42+
ref = ref:totable()
43+
end
44+
local res = nn.utils.get_ngram_recall(cand, ref, n)
45+
local correct = 0
46+
local total = 0
47+
local rouge = 0
48+
weight_sum = 0
49+
50+
for i = 1, n do
51+
local recall = res[i][2]/res[i][1]
52+
rouge = rouge + recall*weight[i]
53+
weight_sum = weight_sum + weight[i]
54+
end
55+
rouge = rouge/weight_sum
56+
return rouge
57+
end
58+
59+
function nn.get_rougeS(cand, ref, beta, dskip)
60+
local beta = beta or 1
61+
beta = beta * beta
62+
63+
local dskip = dskip or (#cand)
64+
dskip = math.min(dskip, #cand)
65+
if type(cand) ~= 'table' then
66+
cand = cand:totable()
67+
end
68+
if type(ref) ~= 'table' then
69+
ref = ref:totable()
70+
end
71+
local cand_unigrams = nn.utils.get_ngrams(cand, 1)
72+
local ref_unigrams = nn.utils.get_ngrams(ref, 1)
73+
74+
local cand_skip_bigrams = nn.utils.get_skip_bigrams(cand, ref_unigrams, 1, dskip)
75+
local ref_skip_bigrams = nn.utils.get_skip_bigrams(ref, cand_unigrams, 1, dskip)
76+
local correct = 0
77+
78+
for bigram, freq in pairs(ref_skip_bigrams) do
79+
local actual
80+
if cand_skip_bigrams[bigram] == nil then
81+
actual = 0
82+
else
83+
actual = cand_skip_bigrams[bigram]
84+
end
85+
correct = correct + math.min(actual, freq)
86+
end
87+
local total_skip_bigrams_ref = (dskip - 1)*(2 * #ref - dskip)/2
88+
local total_skip_bigrams_cand = (dskip - 1)*(2 * #cand - dskip)/2
89+
local rskip2 = correct/total_skip_bigrams_cand
90+
local pskip2 = correct/total_skip_bigrams_ref
91+
local rouge = (1 + beta)*rskip2*pskip2/(rskip2 + beta*pskip2)
92+
return rouge
93+
end
94+

scripts/evaluate-rnnlm.lua

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -42,77 +42,6 @@ local validerr = xplog.valnceloss or xplog.valppl
4242

4343
print(string.format("Error (epoch=%d): training=%f; validation=%f", xplog.epoch, trainerr[#trainerr], validerr[#validerr]))
4444

45-
46-
local function get_ngrams(sent, n, count)
47-
local ngrams = {}
48-
for beg = 1, #sent do
49-
for last= beg, math.min(beg+n-1, #sent) do
50-
local ngram = table.concat(sent, ' ', beg, last)
51-
local len = last-beg+1 -- keep track of ngram length
52-
if not count then
53-
table.insert(ngrams, ngram)
54-
else
55-
if ngrams[ngram] == nil then
56-
ngrams[ngram] = {1, len}
57-
else
58-
ngrams[ngram][1] = ngrams[ngram][1] + 1
59-
end
60-
end
61-
end
62-
end
63-
return ngrams
64-
end
65-
66-
local function get_ngram_prec(cand, ref, n)
67-
local results = {}
68-
for i = 1, n do
69-
results[i] = {0, 0}
70-
end
71-
local cand_ngrams = get_ngrams(cand, n, 1)
72-
local ref_ngrams = get_ngrams(ref, n, 1)
73-
for ngram, dist in pairs(cand_ngrams) do
74-
local freq = dist[1]
75-
local length = dist[2]
76-
results[length][1] = results[length][1] + freq
77-
local actual
78-
if ref_ngrams[ngram] == nil then
79-
actual = 0
80-
else
81-
actual = ref_ngrams[ngram][1]
82-
end
83-
results[length][2] = results[length][2] + math.min(actual, freq)
84-
end
85-
return results
86-
end
87-
88-
function get_bleu(cand, ref, n)
89-
n = n or 4
90-
local smooth = 1
91-
if type(cand) ~= 'table' then
92-
cand = cand:totable()
93-
end
94-
if type(ref) ~= 'table' then
95-
ref = ref:totable()
96-
end
97-
local res = get_ngram_prec(cand, ref, n)
98-
local brevPen = math.exp(1-math.max(1, #ref/#cand))
99-
local correct = 0
100-
local total = 0
101-
local bleu = 1
102-
for i = 1, n do
103-
if res[i][1] > 0 then
104-
if res[i][2] == 0 then
105-
smooth = smooth*0.5
106-
res[i][2] = smooth
107-
end
108-
local prec = res[i][2]/res[i][1]
109-
bleu = bleu * prec
110-
end
111-
end
112-
bleu = bleu^(1/n)
113-
return bleu*brevPen
114-
end
115-
11645
if opt.dumpcsv then
11746
local csvfile = opt.xplogpath:match('([^/]+)[.]t7$')..'.csv'
11847
paths.mkdir('learningcurves')
@@ -220,7 +149,7 @@ else
220149
if opt.bleu then
221150
max_ind = torch.multinomial(torch.exp(outputs:view(targets:nElement(), -1)), 1):view(targets:size(1),targets:size(2))
222151
for batchIdx=1, targets:size(2) do
223-
sum_bleu = sum_bleu + get_bleu(max_ind:select(2, batchIdx),
152+
sum_bleu = sum_bleu + nn.get_bleu(max_ind:select(2, batchIdx),
224153
targets:select(2, batchIdx),
225154
opt.blueN)
226155
num_sent = num_sent + 1

test/test.lua

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6464,6 +6464,39 @@ function rnntest.NCE_multicuda()
64646464
mytester:assertTensorEq(nce2.gradWeight[{{},{1+(hiddensize/2), hiddensize}}]:float(), nce.gradWeight.tensors[2]:float(), 0.000001)
64656465
end
64666466

6467+
function rnntest.bleu()
6468+
local cand = {1, 2, 3, 2, 3, 2, 3, 4, 5, 2, 1, 3 ,2 ,3}
6469+
local ref = {3, 2, 3, 2, 1, 3, 2, 3 ,2 , 3, 4, 5, 2, 1, 3}
6470+
local bleu = nn.get_bleu(cand, ref, 4)
6471+
mytester:assert(math.abs(bleu - 0.83101069788036) < 0.000001)
6472+
6473+
end
6474+
6475+
function rnntest.get_rougeN()
6476+
6477+
local cand = {1, 2, 3, 2, 3, 2, 3, 4, 5, 2, 1, 3 ,2 ,3}
6478+
local ref = {3, 2, 3, 2, 1, 3, 2, 3 ,2 , 3, 4, 5, 2, 1, 3}
6479+
local rouge = nn.get_rougeN(cand, ref, 4)
6480+
mytester:assert(math.abs(rouge - 0.75)< 0.000000000001)
6481+
end
6482+
6483+
6484+
function rnntest.get_rougeS()
6485+
local cand_tbl = {
6486+
'police kill the gunman',
6487+
'the gunman kill police',
6488+
'the gunman police killed'
6489+
}
6490+
local ref_str = "police killed the gunman"
6491+
local ref = ref_str:split(" ")
6492+
local rouge = nn.get_rougeS(cand_tbl[1]:split(" "), ref, 1, 4)
6493+
mytester:assert(math.abs(rouge - 1/2)< 0.000000000001)
6494+
local rouge = nn.get_rougeS(cand_tbl[2]:split(" "), ref, 1, 4)
6495+
mytester:assert(math.abs(rouge - 1/6)< 0.000000000001)
6496+
local rouge = nn.get_rougeS(cand_tbl[3]:split(" "), ref, 1, 4)
6497+
mytester:assert(math.abs(rouge - 1/3)< 0.000000000001)
6498+
end
6499+
64676500
function rnn.test(tests, exclude, benchmark_)
64686501
benchmark = benchmark_
64696502
mytester = torch.Tester()

utils.lua

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,4 +290,89 @@ function nn.utils.setZeroMask(modules, zeroMask, cuda)
290290
for i,module in ipairs(torch.type(modules) == 'table' and modules or {modules}) do
291291
module:setZeroMask(zeroMask)
292292
end
293-
end
293+
end
294+
function nn.utils.get_ngrams(sent, n, count)
295+
local ngrams = {}
296+
for beg = 1, #sent do
297+
for last= beg, math.min(beg+n-1, #sent) do
298+
local ngram = table.concat(sent, ' ', beg, last)
299+
local len = last-beg+1 -- keep track of ngram length
300+
if not count then
301+
ngrams[ngram] = 1
302+
else
303+
if ngrams[ngram] == nil then
304+
ngrams[ngram] = {1, len}
305+
else
306+
ngrams[ngram][1] = ngrams[ngram][1] + 1
307+
end
308+
end
309+
end
310+
end
311+
return ngrams
312+
end
313+
314+
function nn.utils.get_skip_bigrams(sent, ref, count, dskip)
315+
local skip_bigrams = {}
316+
ref = ref or sent
317+
for beg = 1, #sent do
318+
if ref[sent[beg]] then
319+
local temp_token = sent[beg]
320+
for last= beg+1, math.min(beg + dskip-1, #sent) do
321+
if ref[sent[last]] then
322+
skip_bigram = temp_token..sent[last]
323+
if not count then
324+
skip_bigrams[skip_bigram] = 1
325+
else
326+
skip_bigrams[skip_bigram] = (skip_bigram[bigram] or 0) + 1
327+
end
328+
end
329+
end
330+
end
331+
end
332+
return skip_bigrams
333+
end
334+
335+
336+
function nn.utils.get_ngram_prec(cand, ref, n)
337+
local results = {}
338+
for i = 1, n do
339+
results[i] = {0, 0}
340+
end
341+
local cand_ngrams = nn.utils.get_ngrams(cand, n, 1)
342+
local ref_ngrams = nn.utils.get_ngrams(ref, n, 1)
343+
for ngram, dist in pairs(cand_ngrams) do
344+
local freq = dist[1]
345+
local length = dist[2]
346+
results[length][1] = results[length][1] + freq
347+
local actual
348+
if ref_ngrams[ngram] == nil then
349+
actual = 0
350+
else
351+
actual = ref_ngrams[ngram][1]
352+
end
353+
results[length][2] = results[length][2] + math.min(actual, freq)
354+
end
355+
return results
356+
end
357+
358+
function nn.utils.get_ngram_recall(cand, ref, n)
359+
local results = {}
360+
for i = 1, n do
361+
results[i] = {0, 0}
362+
end
363+
local cand_ngrams = nn.utils.get_ngrams(cand, n, 1)
364+
local ref_ngrams = nn.utils.get_ngrams(ref, n, 1)
365+
for ngram, dist in pairs(ref_ngrams) do
366+
local freq = dist[1]
367+
local length = dist[2]
368+
results[length][1] = results[length][1] + freq
369+
local actual
370+
if cand_ngrams[ngram] == nil then
371+
actual = 0
372+
else
373+
actual = cand_ngrams[ngram][1]
374+
end
375+
results[length][2] = results[length][2] + math.min(actual, freq)
376+
end
377+
return results
378+
end

0 commit comments

Comments
 (0)