-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
101 lines (77 loc) · 2.9 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#! /usr/bin/env python
# coding=utf-8
# /************************************************************************************
# ***
# *** File Author: Dell, Thu Sep 20 21:42:14 CST 2018
# ***
# ************************************************************************************/
import argparse
import os
from os.path import basename
from os.path import splitext
import logging
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import model
# Command-line interface for the style-transfer script.
# The single-dash long options (-style, -content, -output, -alpha) are kept
# as-is so existing invocations keep working.
parser = argparse.ArgumentParser()
# -style / -content are mandatory. Declaring them required lets argparse
# report a proper usage error; an `assert` would be stripped under `python -O`.
# Each may be a single file or a directory of images.
parser.add_argument(
    '-style', type=str, required=True, help='File path to the style image')
parser.add_argument(
    '-content', type=str, required=True,
    help='File path to the content image')
parser.add_argument(
    '-output',
    type=str,
    default='output',
    help='Directory to save the output images')
parser.add_argument(
    '-e', '--encoder', type=str, default='models/encoder.pth',
    help='Path to the encoder weights file')
parser.add_argument(
    '-d', '--decoder', type=str, default='models/decoder.pth',
    help='Path to the decoder weights file')
parser.add_argument(
    '-alpha',
    type=float,
    default=1.0,
    help='The weight that controls the degree of stylization. Should be [0, 1]'
)
def _collect_paths(path):
    """Expand a CLI path argument into a list of input image paths.

    If *path* is a directory, return every regular file directly inside it.
    The list is sorted so processing order is deterministic, because
    ``os.listdir`` order is arbitrary. Subdirectories are skipped, since
    ``Image.open`` cannot read them. Any other *path* is returned unchanged
    as a one-element list.
    """
    if os.path.isdir(path):
        candidates = (os.path.join(path, name) for name in os.listdir(path))
        return sorted(p for p in candidates if os.path.isfile(p))
    return [path]


def _load_image(path, transform, device):
    """Open *path* as RGB, apply *transform*, and return a batch of one.

    The image file is opened in a ``with`` block so its handle is closed
    once the pixels have been converted. The transformed tensor is moved to
    *device*, and a leading batch dimension is added with ``unsqueeze(0)``.
    """
    with Image.open(path) as img:
        tensor = transform(img.convert('RGB'))
    return tensor.to(device).unsqueeze(0)


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
    args = parser.parse_args()
    # Explicit check instead of `assert`, which `python -O` strips out.
    if not (args.content and args.style):
        parser.error('both -content and -style must be given')

    style_paths = _collect_paths(args.style)
    content_paths = _collect_paths(args.content)
    # An empty input directory would otherwise silently produce no output.
    if not style_paths or not content_paths:
        parser.error('no input images found for -style / -content')

    # makedirs (unlike mkdir) also creates any missing parent directories.
    if not os.path.isdir(args.output):
        os.makedirs(args.output)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    encoder = model.encoder_load(args.encoder)
    encoder.eval()
    encoder.to(device)
    decoder = model.decoder_load(args.decoder)
    decoder.eval()
    decoder.to(device)

    T = transforms.Compose([
        transforms.ToTensor(),
    ])

    for content_path in content_paths:
        # The content tensor does not depend on the style, so load it once
        # per content file rather than once per (content, style) pair.
        content = _load_image(content_path, T, device)
        content_stem = splitext(basename(content_path))[0]
        for style_path in style_paths:
            logging.info("Transfering from %s to %s ...",
                         style_path, content_path)
            style = _load_image(style_path, T, device)
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                output = model.style_transfer(encoder, decoder, content,
                                              style, args.alpha)
            output = output.cpu()
            output_name = os.path.join(
                args.output,
                '{:s}_stylized_{:s}{:s}'.format(
                    content_stem, splitext(basename(style_path))[0], ".jpg"))
            save_image(output, output_name)