Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 0584c15

Browse files
author
Ryan Sepassi
committed
Update xception_test with correct target modality
PiperOrigin-RevId: 187658067
1 parent 11f6576 commit 0584c15

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tensor2tensor/models/xception_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
class XceptionTest(tf.test.TestCase):
3434

35-
def _testXception(self, img_size, output_size):
35+
def _testXception(self, img_size):
3636
vocab_size = 9
3737
batch_size = 3
3838
x = np.random.random_integers(
@@ -42,6 +42,7 @@ def _testXception(self, img_size, output_size):
4242
hparams = xception.xception_tiny()
4343
p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size)
4444
p_hparams.input_modality["inputs"] = (registry.Modalities.IMAGE, None)
45+
p_hparams.target_modality = (registry.Modalities.CLASS_LABEL, vocab_size)
4546
with self.test_session() as session:
4647
features = {
4748
"inputs": tf.constant(x, dtype=tf.int32),
@@ -51,13 +52,13 @@ def _testXception(self, img_size, output_size):
5152
logits, _ = model(features)
5253
session.run(tf.global_variables_initializer())
5354
res = session.run(logits)
54-
self.assertEqual(res.shape, output_size + (1, vocab_size))
55+
self.assertEqual(res.shape, (batch_size, 1, 1, 1, vocab_size))
5556

56-
def testXceptionSmall(self):
57-
self._testXception(img_size=9, output_size=(3, 5, 5))
57+
def testXceptionSmallImage(self):
58+
self._testXception(img_size=9)
5859

59-
def testXceptionLarge(self):
60-
self._testXception(img_size=256, output_size=(3, 8, 8))
60+
def testXceptionLargeImage(self):
61+
self._testXception(img_size=256)
6162

6263

6364
if __name__ == "__main__":

0 commit comments

Comments
 (0)