@@ -77,12 +77,17 @@ def cifar_generator(cifar_version, tmp_dir, training, how_many, start_from=0):
77
77
test_files = _CIFAR10_TEST_FILES
78
78
prefix = _CIFAR10_PREFIX
79
79
image_size = _CIFAR10_IMAGE_SIZE
80
- elif cifar_version == "cifar100" :
80
+ label_key = "labels"
81
+ elif cifar_version == "cifar100" or cifar_version == "cifar20" :
81
82
url = _CIFAR100_URL
82
83
train_files = _CIFAR100_TRAIN_FILES
83
84
test_files = _CIFAR100_TEST_FILES
84
85
prefix = _CIFAR100_PREFIX
85
86
image_size = _CIFAR100_IMAGE_SIZE
87
+ if cifar_version == "cifar100" :
88
+ label_key = "fine_labels"
89
+ else :
90
+ label_key = "coarse_labels"
86
91
87
92
_get_cifar (tmp_dir , url )
88
93
data_files = train_files if training else test_files
@@ -97,7 +102,7 @@ def cifar_generator(cifar_version, tmp_dir, training, how_many, start_from=0):
97
102
all_images .extend ([
98
103
np .squeeze (images [j ]).transpose ((1 , 2 , 0 )) for j in xrange (num_images )
99
104
])
100
- labels = data ["labels" if cifar_version == "cifar10" else "fine_labels" ]
105
+ labels = data [label_key ]
101
106
all_labels .extend ([labels [j ] for j in xrange (num_images )])
102
107
return image_utils .image_generator (
103
108
all_images [start_from :start_from + how_many ],
@@ -417,3 +422,108 @@ def hparams(self, defaults, unused_model_hparams):
417
422
p .max_expected_batch_size_per_shard = 4
418
423
p .input_space_id = 1
419
424
p .target_space_id = 1
425
+
426
+
427
+ @registry .register_problem
428
+ class ImageCifar20Tune (mnist .ImageMnistTune ):
429
+ """Cifar-20 Tune."""
430
+
431
+ @property
432
+ def num_classes (self ):
433
+ return 20
434
+
435
+ @property
436
+ def num_channels (self ):
437
+ return 3
438
+
439
+ @property
440
+ def class_labels (self ):
441
+ return [
442
+ "aquatic mammals" ,
443
+ "fish" ,
444
+ "flowers" ,
445
+ "food containers" ,
446
+ "fruit and vegetables" ,
447
+ "household electrical devices" ,
448
+ "household furniture" ,
449
+ "insects" ,
450
+ "large carnivores" ,
451
+ "large man-made outdoor things" ,
452
+ "large natural outdoor scenes" ,
453
+ "large omnivores and herbivores" ,
454
+ "medium-sized mammals" ,
455
+ "non-insect invertebrates" ,
456
+ "people" ,
457
+ "reptiles" ,
458
+ "small mammals" ,
459
+ "trees" ,
460
+ "vehicles 1" ,
461
+ "vehicles 2" ,
462
+ ]
463
+
464
+ def preprocess_example (self , example , mode , unused_hparams ):
465
+ image = example ["inputs" ]
466
+ image .set_shape ([_CIFAR100_IMAGE_SIZE , _CIFAR100_IMAGE_SIZE , 3 ])
467
+ if mode == tf .estimator .ModeKeys .TRAIN :
468
+ image = image_utils .cifar_image_augmentation (image )
469
+ if not self ._was_reversed :
470
+ image = tf .image .per_image_standardization (image )
471
+ example ["inputs" ] = image
472
+ return example
473
+
474
+ def generator (self , data_dir , tmp_dir , is_training ):
475
+ if is_training :
476
+ return cifar_generator ("cifar20" , tmp_dir , True , 48000 )
477
+ else :
478
+ return cifar_generator ("cifar20" , tmp_dir , True , 2000 , 48000 )
479
+
480
+
481
+ @registry .register_problem
482
+ class ImageCifar20 (ImageCifar20Tune ):
483
+
484
+ def generator (self , data_dir , tmp_dir , is_training ):
485
+ if is_training :
486
+ return cifar_generator ("cifar20" , tmp_dir , True , 50000 )
487
+ else :
488
+ return cifar_generator ("cifar20" , tmp_dir , False , 10000 )
489
+
490
+
491
+ @registry .register_problem
492
+ class ImageCifar20Plain (ImageCifar20 ):
493
+
494
+ def preprocess_example (self , example , mode , unused_hparams ):
495
+ image = example ["inputs" ]
496
+ image .set_shape ([_CIFAR100_IMAGE_SIZE , _CIFAR100_IMAGE_SIZE , 3 ])
497
+ if not self ._was_reversed :
498
+ image = tf .image .per_image_standardization (image )
499
+ example ["inputs" ] = image
500
+ return example
501
+
502
+
503
+ @registry .register_problem
504
+ class ImageCifar20PlainGen (ImageCifar20Plain ):
505
+ """CIFAR-20 32x32 for image generation without standardization preprep."""
506
+
507
+ def dataset_filename (self ):
508
+ return "image_cifar20_plain" # Reuse CIFAR-20 plain data.
509
+
510
+ def preprocess_example (self , example , mode , unused_hparams ):
511
+ example ["inputs" ].set_shape ([_CIFAR100_IMAGE_SIZE , _CIFAR100_IMAGE_SIZE , 3 ])
512
+ example ["inputs" ] = tf .to_int64 (example ["inputs" ])
513
+ return example
514
+
515
+
516
+ @registry .register_problem
517
+ class ImageCifar20Plain8 (ImageCifar20 ):
518
+ """CIFAR-20 rescaled to 8x8 for output: Conditional image generation."""
519
+
520
+ def dataset_filename (self ):
521
+ return "image_cifar20_plain" # Reuse CIFAR-20 plain data.
522
+
523
+ def preprocess_example (self , example , mode , unused_hparams ):
524
+ image = example ["inputs" ]
525
+ image = image_utils .resize_by_area (image , 8 )
526
+ if not self ._was_reversed :
527
+ image = tf .image .per_image_standardization (image )
528
+ example ["inputs" ] = image
529
+ return example
0 commit comments