Skip to content

Commit 1d60a49

Browse files
author
yusuke-a-uchida
committed
check image num
1 parent 184a6bd commit 1d60a49

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

generator.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77

88
class NoisyImageGenerator(Sequence):
99
def __init__(self, image_dir, source_noise_model, target_noise_model, batch_size=32, image_size=64):
10-
self.image_paths = list(Path(image_dir).glob("*.jpg"))
10+
image_suffixes = (".jpeg", ".jpg", ".png", "bmp")
11+
self.image_paths = [p for p in Path(image_dir).glob("**/*") if p.suffix.lower() in image_suffixes]
1112
self.source_noise_model = source_noise_model
1213
self.target_noise_model = target_noise_model
1314
self.image_num = len(self.image_paths)
1415
self.batch_size = batch_size
1516
self.image_size = image_size
1617

18+
if self.image_num == 0:
19+
raise ValueError("image dir '{}' does not include any image".format(image_dir))
20+
1721
def __len__(self):
1822
return self.image_num // self.batch_size
1923

@@ -45,10 +49,14 @@ def __getitem__(self, idx):
4549

4650
class ValGenerator(Sequence):
4751
def __init__(self, image_dir, val_noise_model):
48-
image_paths = list(Path(image_dir).glob("*.*"))
52+
image_suffixes = (".jpeg", ".jpg", ".png", "bmp")
53+
image_paths = [p for p in Path(image_dir).glob("**/*") if p.suffix.lower() in image_suffixes]
4954
self.image_num = len(image_paths)
5055
self.data = []
5156

57+
if self.image_num == 0:
58+
raise ValueError("image dir '{}' does not include any image".format(image_dir))
59+
5260
for image_path in image_paths:
5361
y = cv2.imread(str(image_path))
5462
h, w, _ = y.shape

0 commit comments

Comments
 (0)