2
2
import numpy as np
3
3
from pathlib import Path
4
4
import cv2
5
- from model import get_srresnet_model
5
+ from model import get_model
6
6
from noise_model import get_noise_model
7
7
8
8
@@ -11,6 +11,8 @@ def get_args():
11
11
formatter_class = argparse .ArgumentDefaultsHelpFormatter )
12
12
parser .add_argument ("--image_dir" , type = str , required = True ,
13
13
help = "test image dir" )
14
+ parser .add_argument ("--model" , type = str , default = "srresnet" ,
15
+ help = "model architecture ('srresnet' or 'unet')" )
14
16
parser .add_argument ("--weight_file" , type = str , required = True ,
15
17
help = "trained weight file" )
16
18
parser .add_argument ("--test_noise_model" , type = str , default = "gaussian,25,25" ,
@@ -31,7 +33,7 @@ def main():
31
33
image_dir = args .image_dir
32
34
weight_file = args .weight_file
33
35
val_noise_model = get_noise_model (args .test_noise_model )
34
- model = get_srresnet_model ( )
36
+ model = get_model ( args . model )
35
37
model .load_weights (weight_file )
36
38
37
39
if args .output_dir :
@@ -43,6 +45,9 @@ def main():
43
45
for image_path in image_paths :
44
46
image = cv2 .imread (str (image_path ))
45
47
h , w , _ = image .shape
48
+ image = image [:(h // 16 ) * 16 , :(w // 16 ) * 16 ] # for stride (maximum 16)
49
+ h , w , _ = image .shape
50
+
46
51
out_image = np .zeros ((h , w * 3 , 3 ), dtype = np .uint8 )
47
52
noise_image = val_noise_model (image )
48
53
pred = model .predict (np .expand_dims (noise_image , 0 ))
0 commit comments