-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpsnr.py
92 lines (69 loc) · 2.2 KB
/
psnr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/python3
"""
This is an example script that parses the input given to a perceptual metric
in the perceptual track of the 2021 CLIC.
The input is a CSV file with 3 columns. Each column contains the path to a PNG
file. The first is the "original", then "method_a" and "method_b".
This script writes (to STDOUT, so you have to redirect its output to a file) a
CSV containing the same first 3 columns as the input, plus a 4th column whose
value is:
0 if PSNR(original, method_a) >= PSNR(original, method_b)
1 otherwise.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import concurrent.futures
import csv
import sys

import cv2
import numpy as np
from absl import app
from absl import flags
from absl import logging
FLAGS = flags.FLAGS

# Path to the evaluation CSV whose rows are (original, method_a, method_b)
# image-path triplets. It has no default; main() reads it via FLAGS.input_file.
flags.DEFINE_string(
    name='input_file',
    default=None,
    help="""Eval CSV file. This has triplets.""",
)
def load_image(fname):
    """Load an image file into a NumPy array using OpenCV.

    Args:
      fname: path to the image file.
    Returns:
      The decoded image as a NumPy array. cv2.imread is called with its
      default flag (IMREAD_COLOR), i.e. a 3-channel 8-bit BGR image.
    Raises:
      IOError: if OpenCV cannot read or decode the file.
    """
    im = cv2.imread(fname)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the failure only shows up later as an obscure astype() error.
    if im is None:
        raise IOError('Could not read image: {}'.format(fname))
    return np.asarray(im)
def read_csv(file_name):
    """Read the evaluation CSV file.

    The CSV file contains 3 columns:
      OriginalFile,FileA,FileB
    OriginalFile: path to the original (uncompressed) image file.
    FileA/FileB: paths to the images generated by the two methods being
      compared.

    Args:
      file_name: file name to read.
    Returns:
      A list of rows in file order, each row a list of strings as parsed by
      csv.reader. Blank lines are skipped.
    Raises:
      ValueError: if a non-blank row has fewer than 3 columns.
    """
    contents = []
    # The csv module documentation requires newline='' for files handed to
    # csv.reader, so quoted fields with embedded newlines parse correctly.
    with open(file_name, newline='') as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            if not row:
                # A blank line parses as []; skip it instead of letting the
                # caller fail later with an IndexError.
                continue
            if len(row) < 3:
                raise ValueError(
                    '{}:{}: expected at least 3 columns, got {}: {!r}'.format(
                        file_name, reader.line_num, len(row), row))
            contents.append(row)
    return contents
def psnr(a, b):
    """Return the peak signal-to-noise ratio, in dB, between two images.

    Pixel values are assumed to be on a 0-255 scale (the peak value used in
    the formula is 255).

    Args:
      a, b: array-likes with the same number of elements. Any numeric dtype
        is accepted; arrays are compared element-wise after flattening.
    Returns:
      10 * log10(255**2 / MSE). When the MSE is below 1e-6 (e.g. identical
      inputs), the fixed value 108.14 is returned to avoid dividing by zero.
    Raises:
      ValueError: if a and b are empty or contain different numbers of
        elements.
    """
    # Work in float64: subtracting uint8 arrays directly would wrap around
    # modulo 256 and give a wrong MSE.
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    if a.size != b.size:
        raise ValueError(
            'psnr() inputs differ in size: {} vs {}'.format(a.size, b.size))
    if a.size == 0:
        raise ValueError('psnr() inputs are empty')
    mse = np.mean((a - b) ** 2)
    if mse < 1e-6:
        # Cap at roughly the PSNR for MSE == 1e-6: 10*log10(255**2/1e-6) ~ 108.13.
        return 108.14
    return 10 * np.log10(255**2 / mse)
def process_triplet(item):
    """Decide which of two reconstructions is closer to the original by PSNR.

    Args:
      item: one CSV row; entries 0, 1 and 2 are the paths of the original,
        method_a and method_b images. Any further entries are ignored.
    Returns:
      A tuple (original_path, method_a_path, method_b_path, choice) where
      choice is 1 if method_a's PSNR is strictly lower than method_b's,
      otherwise 0.
    """
    paths = (item[0], item[1], item[2])
    # Decode all three images up front, converted to float32 for the metric.
    reference, candidate_a, candidate_b = [
        load_image(path).astype(np.float32) for path in paths
    ]
    score_a = psnr(reference, candidate_a)
    score_b = psnr(reference, candidate_b)
    choice = 1 if score_a < score_b else 0
    return paths + (choice,)
def main(argv):
    """Score every triplet listed in --input_file and write a CSV to STDOUT.

    Each output row is the input's first three columns followed by the 0/1
    choice computed by process_triplet.

    Args:
      argv: positional command-line arguments left after absl flag parsing;
        only the program name is expected.
    Raises:
      app.UsageError: on extra positional arguments or a missing --input_file.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    if not FLAGS.input_file:
        raise app.UsageError('--input_file is required.')
    inputs = read_csv(FLAGS.input_file)
    # csv.writer quotes fields containing commas or quotes, so paths like
    # "a,b.png" stay parseable by the same csv module that read the input.
    # lineterminator='\n' keeps the line endings the old print() produced.
    writer = csv.writer(sys.stdout, lineterminator='\n')
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Executor.map yields results in input order, so output rows line up
        # with the input rows even though triplets are processed concurrently.
        for row in executor.map(process_triplet, inputs):
            writer.writerow(row)
if __name__ == '__main__':
    # Make absl reject a missing --input_file before main() runs, instead of
    # failing later. This is done here rather than at import time so that
    # importing this module does not make the flag mandatory.
    flags.mark_flag_as_required('input_file')
    app.run(main)