@@ -6464,6 +6464,39 @@ function rnntest.NCE_multicuda()
6464
6464
mytester :assertTensorEq (nce2 .gradWeight [{{},{1 + (hiddensize / 2 ), hiddensize }}]:float (), nce .gradWeight .tensors [2 ]:float (), 0.000001 )
6465
6465
end
6466
6466
6467
+ function rnntest .bleu ()
6468
+ local cand = {1 , 2 , 3 , 2 , 3 , 2 , 3 , 4 , 5 , 2 , 1 , 3 ,2 ,3 }
6469
+ local ref = {3 , 2 , 3 , 2 , 1 , 3 , 2 , 3 ,2 , 3 , 4 , 5 , 2 , 1 , 3 }
6470
+ local bleu = nn .get_bleu (cand , ref , 4 )
6471
+ mytester :assert (math.abs (bleu - 0.83101069788036 ) < 0.000001 )
6472
+
6473
+ end
6474
+
6475
+ function rnntest .get_rougeN ()
6476
+
6477
+ local cand = {1 , 2 , 3 , 2 , 3 , 2 , 3 , 4 , 5 , 2 , 1 , 3 ,2 ,3 }
6478
+ local ref = {3 , 2 , 3 , 2 , 1 , 3 , 2 , 3 ,2 , 3 , 4 , 5 , 2 , 1 , 3 }
6479
+ local rouge = nn .get_rougeN (cand , ref , 4 )
6480
+ mytester :assert (math.abs (rouge - 0.75 )< 0.000000000001 )
6481
+ end
6482
+
6483
+
6484
+ function rnntest .get_rougeS ()
6485
+ local cand_tbl = {
6486
+ ' police kill the gunman' ,
6487
+ ' the gunman kill police' ,
6488
+ ' the gunman police killed'
6489
+ }
6490
+ local ref_str = " police killed the gunman"
6491
+ local ref = ref_str :split (" " )
6492
+ local rouge = nn .get_rougeS (cand_tbl [1 ]:split (" " ), ref , 1 , 4 )
6493
+ mytester :assert (math.abs (rouge - 1 / 2 )< 0.000000000001 )
6494
+ local rouge = nn .get_rougeS (cand_tbl [2 ]:split (" " ), ref , 1 , 4 )
6495
+ mytester :assert (math.abs (rouge - 1 / 6 )< 0.000000000001 )
6496
+ local rouge = nn .get_rougeS (cand_tbl [3 ]:split (" " ), ref , 1 , 4 )
6497
+ mytester :assert (math.abs (rouge - 1 / 3 )< 0.000000000001 )
6498
+ end
6499
+
6467
6500
function rnn .test (tests , exclude , benchmark_ )
6468
6501
benchmark = benchmark_
6469
6502
mytester = torch .Tester ()
0 commit comments