@@ -51,15 +51,17 @@ def resize_by_area(img, size):
 
 class ImageProblem(problem.Problem):
 
-  def example_reading_spec(self, label_key=None):
-    if label_key is None:
-      label_key = "image/class/label"
+  def example_reading_spec(self, label_repr=None):
+    if label_repr is None:
+      label_repr = ("image/class/label", tf.FixedLenFeature((1,), tf.int64))
 
     data_fields = {
         "image/encoded": tf.FixedLenFeature((), tf.string),
         "image/format": tf.FixedLenFeature((), tf.string),
-        label_key: tf.VarLenFeature(tf.int64)
     }
+    label_key, label_type = label_repr  # pylint: disable=unpacking-non-sequence
+    data_fields[label_key] = label_type
+
     data_items_to_decoders = {
         "inputs":
             tf.contrib.slim.tfexample_decoder.Image(
@@ -244,8 +246,9 @@ def hparams(self, defaults, unused_model_hparams):
 
   def example_reading_spec(self):
     label_key = "image/unpadded_label"
+    label_type = tf.VarLenFeature(tf.int64)
     return super(ImageFSNS, self).example_reading_spec(
-        self, label_key=label_key)
+        self, label_repr=(label_key, label_type))
 
 
 class Image2ClassProblem(ImageProblem):
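Not part of the diff: a short sketch of what the two example_reading_spec variants above amount to once the label_repr tuple is folded into data_fields. The default keeps a fixed-length single class id, while ImageFSNS asks for a variable-length label; the helper names below are hypothetical.

import tensorflow as tf

def parse_class_example(serialized_example):
  """Default label_repr: one int64 class id per image."""
  data_fields = {
      "image/encoded": tf.FixedLenFeature((), tf.string),
      "image/format": tf.FixedLenFeature((), tf.string),
      "image/class/label": tf.FixedLenFeature((1,), tf.int64),
  }
  return tf.parse_single_example(serialized_example, data_fields)

def parse_fsns_example(serialized_example):
  """ImageFSNS label_repr: a variable-length transcription label."""
  data_fields = {
      "image/encoded": tf.FixedLenFeature((), tf.string),
      "image/format": tf.FixedLenFeature((), tf.string),
      "image/unpadded_label": tf.VarLenFeature(tf.int64),
  }
  return tf.parse_single_example(serialized_example, data_fields)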
@@ -283,10 +286,8 @@ def generator(self, data_dir, tmp_dir, is_training):
 
   def hparams(self, defaults, unused_model_hparams):
     p = defaults
-    small_modality = "%s:small_image_modality" % registry.Modalities.IMAGE
-    modality = small_modality if self.is_small else registry.Modalities.IMAGE
-    p.input_modality = {"inputs": (modality, None)}
-    p.target_modality = ("%s:2d" % registry.Modalities.CLASS_LABEL,
+    p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)}
+    p.target_modality = (registry.Modalities.CLASS_LABEL,
                          self.num_classes)
     p.batch_size_multiplier = 4 if self.is_small else 256
     p.max_expected_batch_size_per_shard = 8 if self.is_small else 2
@@ -382,6 +383,38 @@ def preprocess_example(self, example, mode, unused_hparams):
     return example
 
 
+@registry.register_problem
+class ImageImagenet64(Image2ClassProblem):
+  """Imagenet rescaled to 64x64."""
+
+  def dataset_filename(self):
+    return "image_imagenet"  # Reuse Imagenet data.
+
+  @property
+  def is_small(self):
+    return True  # Modalities like for CIFAR.
+
+  @property
+  def num_classes(self):
+    return 1000
+
+  def generate_data(self, data_dir, tmp_dir, task_id=-1):
+    # TODO(lukaszkaiser): find a better way than printing this.
+    print("To generate the ImageNet dataset in the proper format, follow "
+          "instructions at https://github.com/tensorflow/models/blob/master"
+          "/inception/README.md#getting-started")
+
+  def preprocess_example(self, example, mode, unused_hparams):
+    inputs = example["inputs"]
+    # Just resize with area.
+    if self._was_reversed:
+      example["inputs"] = resize_by_area(inputs, 64)
+    else:
+      example = imagenet_preprocess_example(example, mode)
+      example["inputs"] = resize_by_area(example["inputs"], 64)
+    return example
+
+
 @registry.register_problem
 class Img2imgImagenet(ImageProblem):
   """Imagenet rescaled to 8x8 for input and 32x32 for output."""
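The new ImageImagenet64 problem reuses the TFRecords already generated for image_imagenet (via dataset_filename) and only changes preprocessing to 64x64 area resizing. A hedged usage sketch, assuming the registry exposes the class under the snake_case name image_imagenet64:

from tensor2tensor.utils import registry

# Look up the registered problem; it reads the image_imagenet files but
# resizes inputs to 64x64 in preprocess_example. When the problem is
# reversed, ImageNet train-time preprocessing is skipped and only the
# area resize is applied.
imagenet64 = registry.problem("image_imagenet64")
print(imagenet64.dataset_filename())  # "image_imagenet"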
@@ -623,9 +656,11 @@ def class_labels(self):
   ]
 
   def preprocess_example(self, example, mode, unused_hparams):
+    example["inputs"].set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
     if mode == tf.estimator.ModeKeys.TRAIN:
       example["inputs"] = common_layers.cifar_image_augmentation(
           example["inputs"])
+    example["inputs"] = tf.to_int64(example["inputs"])
     return example
 
   def generator(self, data_dir, tmp_dir, is_training):
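The CIFAR-10 change pins a static 32x32x3 shape on the decoded image before augmentation and casts to int64 afterwards. A small illustrative snippet, not from the commit, assuming _CIFAR10_IMAGE_SIZE is 32:

import tensorflow as tf

_CIFAR10_IMAGE_SIZE = 32  # assumed value of the module constant

# A decoded image often carries an unknown static shape; pinning it lets
# shape-dependent ops in the augmentation build cleanly, and the final
# cast produces the integer inputs the image modalities expect.
image = tf.placeholder(tf.uint8, shape=[None, None, 3])
image.set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
image = tf.to_int64(image)
print(image.shape)  # (32, 32, 3)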
@@ -649,6 +684,7 @@ def generator(self, data_dir, tmp_dir, is_training):
 class ImageCifar10Plain(ImageCifar10):
 
   def preprocess_example(self, example, mode, unused_hparams):
+    example["inputs"].set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
     example["inputs"] = tf.to_int64(example["inputs"])
     return example
 