@@ -1637,6 +1637,8 @@ def test_windowed_nd_big():
16371637
16381638
16391639def naive_slice_nd (x , start , size ):
1640+ # old implementation, check out naive_slice_nd2
1641+
16401642 slices_shape = [x .shape [0 ], size ] + list (x .shape )[2 :]
16411643 ys = numpy .zeros (shape = slices_shape )
16421644 for i in range (len (start )):
@@ -1653,6 +1655,63 @@ def naive_slice_nd(x, start, size):
16531655 return ys
16541656
16551657
1658+ def naive_slice_nd2 (x , start , size ):
1659+ # Assuming that x: [B, T1, T2, .., Tn, D] and start: [B, T1, .., Tn-1]
1660+ # i.e. the dimensions of x and start are ordered accordingly.
1661+ # (Otherwise we should require the slice axis too.)
1662+
1663+ len_common_dims = len (start .shape )
1664+ slice_shape = (size ,) + x .shape [len_common_dims + 1 :]
1665+ result_shape = start .shape [0 :len_common_dims ] + slice_shape # shape of output
1666+ result = numpy .zeros (result_shape )
1667+
1668+ slice_axis_dim = x .shape [len_common_dims ] # dim of axis being sliced
1669+ for index , start_position in numpy .ndenumerate (start ):
1670+ end_position = min (start_position + size , slice_axis_dim ) # padding required
1671+
1672+ # no padding
1673+ padding = ((0 ,0 ),)
1674+ for i in range (1 , len (slice_shape )):
1675+ padding += ((0 , 0 ),)
1676+
1677+ # if required replace the first padding tuple, which corresponds to the slice axis
1678+ if end_position < start_position + size :
1679+ padding = ((0 ,size - end_position + start_position ),) + padding [1 :]
1680+ result [index ] = numpy .pad (x [index ][start_position :end_position ], padding , mode = 'constant' , constant_values = 0 )
1681+ return result
1682+
1683+
1684+ def test_slice_nd_multi_dim ():
1685+ n_batch = 2
1686+ n_time_1 = 2
1687+ n_time_2 = 3 # slice axis
1688+ n_dim = 2
1689+ size = 2
1690+ source = numpy .arange (24 , dtype = numpy .float32 ).reshape (n_batch , n_time_1 , n_time_2 , n_dim ).astype ("float32" )
1691+ start = numpy .array ([[0 ,1 ],[1 ,2 ]]).astype ("int32" )
1692+ naive = naive_slice_nd2 (source , start , size )
1693+ source_tf = tf .constant (source )
1694+ real = slice_nd2 (source_tf , start = start , size = size ).eval ()
1695+ print ("source:" )
1696+ print (source )
1697+ print ("naive:" )
1698+ print (naive )
1699+ print ("real:" )
1700+ print (real )
1701+ expected_output = numpy .array (
1702+ [[[[0 , 1 ],
1703+ [2 , 3 ]],
1704+ [[8 , 9 ],
1705+ [10 , 11 ]]],
1706+
1707+ [[[14 , 15 ],
1708+ [16 , 17 ]],
1709+ [[22 , 23 ],
1710+ [0 , 0 ]]]]) # padding
1711+ numpy .testing .assert_almost_equal (naive , expected_output )
1712+ numpy .testing .assert_almost_equal (real , expected_output )
1713+
1714+
16561715def test_slice_nd_small ():
16571716 n_batch = 3
16581717 n_time = 4
@@ -1662,7 +1721,7 @@ def test_slice_nd_small():
16621721 source = numpy .arange (1 , n_batch * n_time * n_dim + 1 , dtype = numpy .float32 ).reshape (n_batch , n_time , n_dim ).astype ("float32" )
16631722 source_tf = tf .constant (source )
16641723 naive = naive_slice_nd (source , start , size )
1665- real = slice_nd (source_tf , start = start , size = size ).eval ()
1724+ real = slice_nd2 (source_tf , start = start , size = size ).eval ()
16661725 print ("source:" )
16671726 print (source )
16681727 print ("naive:" )
@@ -1682,7 +1741,7 @@ def test_slice_nd_big():
16821741 source = numpy .random .random ((n_batch , n_time , n_dim )).astype ("float32" )
16831742 source_tf = tf .constant (source )
16841743 naive = naive_slice_nd (source , start , size )
1685- real = slice_nd (source_tf , start = start , size = size ).eval ()
1744+ real = slice_nd2 (source_tf , start = start , size = size ).eval ()
16861745 print ("source:" )
16871746 print (source )
16881747 print ("naive:" )
0 commit comments