Skip to content

Commit e457a74

Browse files
author
yusuke-a-uchida
committed
test unet model
1 parent 66550b9 commit e457a74

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

test_model.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
from pathlib import Path
44
import cv2
5-
from model import get_srresnet_model
5+
from model import get_model
66
from noise_model import get_noise_model
77

88

@@ -11,6 +11,8 @@ def get_args():
1111
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
1212
parser.add_argument("--image_dir", type=str, required=True,
1313
help="test image dir")
14+
parser.add_argument("--model", type=str, default="srresnet",
15+
help="model architecture ('srresnet' or 'unet')")
1416
parser.add_argument("--weight_file", type=str, required=True,
1517
help="trained weight file")
1618
parser.add_argument("--test_noise_model", type=str, default="gaussian,25,25",
@@ -31,7 +33,7 @@ def main():
3133
image_dir = args.image_dir
3234
weight_file = args.weight_file
3335
val_noise_model = get_noise_model(args.test_noise_model)
34-
model = get_srresnet_model()
36+
model = get_model(args.model)
3537
model.load_weights(weight_file)
3638

3739
if args.output_dir:
@@ -43,6 +45,9 @@ def main():
4345
for image_path in image_paths:
4446
image = cv2.imread(str(image_path))
4547
h, w, _ = image.shape
48+
image = image[:(h // 16) * 16, :(w // 16) * 16] # for stride (maximum 16)
49+
h, w, _ = image.shape
50+
4651
out_image = np.zeros((h, w * 3, 3), dtype=np.uint8)
4752
noise_image = val_noise_model(image)
4853
pred = model.predict(np.expand_dims(noise_image, 0))

0 commit comments

Comments
 (0)