|
7 | 7 |
|
8 | 8 | class NoisyImageGenerator(Sequence):
|
9 | 9 | def __init__(self, image_dir, source_noise_model, target_noise_model, batch_size=32, image_size=64):
|
10 |
| - self.image_paths = list(Path(image_dir).glob("*.jpg")) |
| 10 | + image_suffixes = (".jpeg", ".jpg", ".png", "bmp") |
| 11 | + self.image_paths = [p for p in Path(image_dir).glob("**/*") if p.suffix.lower() in image_suffixes] |
11 | 12 | self.source_noise_model = source_noise_model
|
12 | 13 | self.target_noise_model = target_noise_model
|
13 | 14 | self.image_num = len(self.image_paths)
|
14 | 15 | self.batch_size = batch_size
|
15 | 16 | self.image_size = image_size
|
16 | 17 |
|
| 18 | + if self.image_num == 0: |
| 19 | + raise ValueError("image dir '{}' does not include any image".format(image_dir)) |
| 20 | + |
17 | 21 | def __len__(self):
|
18 | 22 | return self.image_num // self.batch_size
|
19 | 23 |
|
@@ -45,10 +49,14 @@ def __getitem__(self, idx):
|
45 | 49 |
|
46 | 50 | class ValGenerator(Sequence):
|
47 | 51 | def __init__(self, image_dir, val_noise_model):
|
48 |
| - image_paths = list(Path(image_dir).glob("*.*")) |
| 52 | + image_suffixes = (".jpeg", ".jpg", ".png", "bmp") |
| 53 | + image_paths = [p for p in Path(image_dir).glob("**/*") if p.suffix.lower() in image_suffixes] |
49 | 54 | self.image_num = len(image_paths)
|
50 | 55 | self.data = []
|
51 | 56 |
|
| 57 | + if self.image_num == 0: |
| 58 | + raise ValueError("image dir '{}' does not include any image".format(image_dir)) |
| 59 | + |
52 | 60 | for image_path in image_paths:
|
53 | 61 | y = cv2.imread(str(image_path))
|
54 | 62 | h, w, _ = y.shape
|
|
0 commit comments