3232
3333class XceptionTest (tf .test .TestCase ):
3434
35- def _testXception (self , img_size , output_size ):
35+ def _testXception (self , img_size ):
3636 vocab_size = 9
3737 batch_size = 3
3838 x = np .random .random_integers (
@@ -42,6 +42,7 @@ def _testXception(self, img_size, output_size):
4242 hparams = xception .xception_tiny ()
4343 p_hparams = problem_hparams .test_problem_hparams (vocab_size , vocab_size )
4444 p_hparams .input_modality ["inputs" ] = (registry .Modalities .IMAGE , None )
45+ p_hparams .target_modality = (registry .Modalities .CLASS_LABEL , vocab_size )
4546 with self .test_session () as session :
4647 features = {
4748 "inputs" : tf .constant (x , dtype = tf .int32 ),
@@ -51,13 +52,13 @@ def _testXception(self, img_size, output_size):
5152 logits , _ = model (features )
5253 session .run (tf .global_variables_initializer ())
5354 res = session .run (logits )
54- self .assertEqual (res .shape , output_size + ( 1 , vocab_size ))
55+ self .assertEqual (res .shape , ( batch_size , 1 , 1 , 1 , vocab_size ))
5556
56- def testXceptionSmall (self ):
57- self ._testXception (img_size = 9 , output_size = ( 3 , 5 , 5 ) )
57+ def testXceptionSmallImage (self ):
58+ self ._testXception (img_size = 9 )
5859
59- def testXceptionLarge (self ):
60- self ._testXception (img_size = 256 , output_size = ( 3 , 8 , 8 ) )
60+ def testXceptionLargeImage (self ):
61+ self ._testXception (img_size = 256 )
6162
6263
6364if __name__ == "__main__" :
0 commit comments