-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnet.py
220 lines (164 loc) · 8.27 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import numpy as np
import os
import torch
import torchvision.models as models
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import sys
import math
import torch.nn.init as init
import logging
from torch.nn.parameter import Parameter
from subnet import *
import torchac
def save_model(model, iter, snapshot_dir="autodl-tmp/snapshot"):
    """Save ``model``'s state_dict to ``<snapshot_dir>/iter<iter>.model``.

    Args:
        model: module whose ``state_dict()`` is saved.
        iter: step number embedded in the file name (the parameter name is
            kept for backward compatibility even though it shadows the
            builtin ``iter``).
        snapshot_dir: target directory; defaults to the previously hard-coded
            location. It is not created here, so it must already exist.
    """
    torch.save(model.state_dict(), "{}/iter{}.model".format(snapshot_dir, iter))
def load_model(model, f):
    """Copy matching weights from checkpoint ``f`` into ``model`` and return its step.

    Only checkpoint entries whose keys also exist in ``model.state_dict()``
    are copied; all other model entries keep their current values.

    Args:
        model: module to update in place.
        f: path of a file readable by ``torch.load`` that holds a state_dict.

    Returns:
        The integer following ``iter`` in a file name of the form
        ``...iter<digits>.model...`` (the pattern ``save_model`` writes),
        or 0 when the file name contains no such number.
    """
    import re  # local import keeps this edit self-contained

    # Keep the file handle separate from the path: the old code rebound ``f``
    # to the file object and then parsed ``str(<file object>)``.
    with open(f, 'rb') as fh:
        pretrained_dict = torch.load(fh)
    model_dict = model.state_dict()
    model_dict.update({k: v for k, v in pretrained_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    # Parse only the file name, so an "iter" inside a directory name is never
    # matched (previously that made int() fail on the directory text).
    match = re.search(r'iter(\d+)\.model', os.path.basename(str(f)))
    return int(match.group(1)) if match else 0
class VideoCompressor(nn.Module):
    """Inter-frame compressor: codes a frame given a reference frame.

    ``forward`` runs optical flow between the two frames, codes the flow with
    an MV autoencoder, motion-compensates the reference, and codes the
    prediction residual with an autoencoder whose latent scale is predicted by
    a hyperprior. It returns distortions and bits-per-pixel rate terms.
    """

    def __init__(self):
        super(VideoCompressor, self).__init__()
        # Intra-frame coder is disabled (see forwardFirstFrame).
        # self.imageCompressor = ImageCompressor()
        self.opticFlow = LiteFlowNet3()  # alternative: ME_Spynet()
        self.mvEncoder = Analysis_mv_net()
        self.Q = None
        self.mvDecoder = Synthesis_mv_net()
        self.warpnet = Warp_net()
        self.resEncoder = Analysis_net()
        self.resDecoder = Synthesis_net()
        self.respriorEncoder = Analysis_prior_net()
        self.respriorDecoder = Synthesis_prior_net()
        # Per-channel learned CDFs for the hyper-latent z and the MV latent; a
        # symbol's probability is estimator(x + 0.5) - estimator(x - 0.5).
        self.bitEstimator_z = BitEstimator(out_channel_N)
        self.bitEstimator_mv = BitEstimator(out_channel_mv)
        # self.flow_warp = Resample2d()
        # self.bitEstimator_feature = BitEstimator(out_channel_M)
        self.warp_weight = 0
        # Latents are shifted by +mxrange and cast to int16 before torchac coding;
        # values outside roughly [-mxrange, mxrange) fail torchac's bounds check.
        self.mxrange = 150
        # When True (and in eval mode), rates are actual torchac bitstream sizes
        # instead of the probability-based estimate.
        self.calrealbits = False

    def forwardFirstFrame(self, x):
        """Intra-frame path.

        NOTE(review): ``self.imageCompressor`` is commented out in ``__init__``
        and ``self.bitEstimator`` is never assigned, so as written this raises
        AttributeError when called.
        """
        output, bittrans = self.imageCompressor(x)
        cost = self.bitEstimator(bittrans)
        return output, cost

    def motioncompensation(self, ref, mv):
        """Warp ``ref`` by ``mv`` and refine the result with ``warpnet``.

        Returns ``(prediction, warpframe)``: the warped frame plus the
        ``warpnet`` correction, and the raw warped frame.
        """
        warpframe = flow_warp(ref, mv)
        inputfeature = torch.cat((warpframe, ref), 1)
        prediction = self.warpnet(inputfeature) + warpframe
        return prediction, warpframe

    def _quantize(self, x, noise):
        """Round ``x`` in eval mode; add ``noise`` to it in training mode.

        In training mode a missing ``noise`` is replaced by uniform noise in
        [-0.5, 0.5) with the shape of ``x`` (previously ``x + None`` raised).
        """
        if not self.training:
            return torch.round(x)
        if noise is None:
            noise = torch.empty_like(x).uniform_(-0.5, 0.5)
        return x + noise

    @staticmethod
    def _bits_from_probs(probs):
        """Sum of -log2(p + 1e-5) over all elements, each term clamped to [0, 50]."""
        return torch.sum(torch.clamp(-1.0 * torch.log(probs + 1e-5) / math.log(2.0), 0, 50))

    def _coded_bits(self, x, cdfs):
        """Entropy-code ``x`` with torchac; return ``(decoded, bits)``.

        ``cdfs`` must have shape ``x.shape + (2 * mxrange,)``. ``bits`` is a
        one-element float tensor on ``x``'s device.
        """
        cdfs = cdfs.cpu().detach()
        symbols = (x + self.mxrange).cpu().detach().to(torch.int16)
        byte_stream = torchac.encode_float_cdf(cdfs, symbols, check_input_bounds=True)
        real_bits = torch.tensor([len(byte_stream) * 8], dtype=torch.float32, device=x.device)
        sym_out = torchac.decode_float_cdf(cdfs, byte_stream)
        return sym_out - self.mxrange, real_bits

    def _laplace_bits(self, feature, sigma):
        """Rate of ``feature`` under a zero-mean Laplace with scale ``sigma``.

        ``sigma`` is clamped to [1e-5, 1e10]. Returns ``(total_bits, probs)``;
        ``total_bits`` is replaced by the torchac bitstream size when
        ``calrealbits`` is set and the module is in eval mode.
        """
        mu = torch.zeros_like(sigma)
        sigma = sigma.clamp(1e-5, 1e10)
        laplace = torch.distributions.laplace.Laplace(mu, sigma)
        probs = laplace.cdf(feature + 0.5) - laplace.cdf(feature - 0.5)
        total_bits = self._bits_from_probs(probs)
        if self.calrealbits and not self.training:
            n, c, h, w = feature.shape
            # Pass tensors, not Python floats: with distribution argument
            # validation enabled (PyTorch's default) cdf() rejects non-tensors.
            cdfs = torch.cat(
                [laplace.cdf(mu.new_tensor(i - 0.5)).view(n, c, h, w, 1)
                 for i in range(-self.mxrange, self.mxrange)],
                4,
            )
            _, total_bits = self._coded_bits(feature, cdfs)
        return total_bits, probs

    def _factorized_bits(self, x, estimator):
        """Rate of ``x`` under the per-channel learned CDF ``estimator``.

        Returns ``(total_bits, prob)``; ``total_bits`` is replaced by the
        torchac bitstream size when ``calrealbits`` is set in eval mode.
        """
        prob = estimator(x + 0.5) - estimator(x - 0.5)
        total_bits = self._bits_from_probs(prob)
        if self.calrealbits and not self.training:
            n, c, h, w = x.shape
            # The CDF at a scalar is per-channel; tile it to the full symbol
            # shape including the batch dim, because torchac requires
            # cdf.shape[:-1] == symbols.shape (the old repeat used batch 1).
            cdfs = torch.cat(
                [estimator(i - 0.5).view(1, c, 1, 1, 1).repeat(n, 1, h, w, 1)
                 for i in range(-self.mxrange, self.mxrange)],
                4,
            )
            _, real_bits = self._coded_bits(x, cdfs)
            total_bits = torch.sum(real_bits)
        return total_bits, prob

    def forward(self, input_image, referframe, quant_noise_feature=None, quant_noise_z=None, quant_noise_mv=None):
        """Code ``input_image`` predicted from ``referframe``.

        Args:
            input_image: current frame; 4-D with spatial size at dims 2 and 3
                (presumably N, C, H, W — confirm against the data loader).
            referframe: reference frame with the same layout.
            quant_noise_feature, quant_noise_z, quant_noise_mv: additive noise
                used in training mode for the residual latent, hyper-latent and
                MV latent; drawn uniformly in [-0.5, 0.5) when omitted.
                Ignored in eval mode, where latents are rounded.

        Returns:
            ``(clipped_recon_image, mse_loss, warploss, interloss,
            bpp_feature, bpp_z, bpp_mv, bpp)``.
        """
        height, width = input_image.shape[2], input_image.shape[3]

        # Motion estimation, resized to frame resolution and kept on the
        # input's device (previously forced onto the current CUDA device).
        estmv = self.opticFlow(input_image, referframe)
        estmv = F.interpolate(
            estmv, size=(height, width), mode='bilinear', align_corners=False
        ).to(input_image.device)

        # Motion coding and motion compensation.
        mvfeature = self.mvEncoder(estmv)
        quant_mv = self._quantize(mvfeature, quant_noise_mv)
        quant_mv_upsample = self.mvDecoder(quant_mv)
        prediction, warpframe = self.motioncompensation(referframe, quant_mv_upsample)

        # Residual coding; the hyperprior predicts the residual latent's scale.
        input_residual = input_image - prediction
        feature = self.resEncoder(input_residual)
        batch_size = feature.size()[0]
        z = self.respriorEncoder(feature)
        compressed_z = self._quantize(z, quant_noise_z)
        recon_sigma = self.respriorDecoder(compressed_z)
        compressed_feature_renorm = self._quantize(feature, quant_noise_feature)
        recon_res = self.resDecoder(compressed_feature_renorm)
        recon_image = prediction + recon_res
        clipped_recon_image = recon_image.clamp(0., 1.)

        # Distortions; mse_loss uses the unclipped reconstruction.
        mse_loss = torch.mean((recon_image - input_image).pow(2))
        warploss = torch.mean((warpframe - input_image).pow(2))
        interloss = torch.mean((prediction - input_image).pow(2))

        # Rates, normalised to bits per pixel over the whole batch.
        total_bits_feature, _ = self._laplace_bits(compressed_feature_renorm, recon_sigma)
        total_bits_z, _ = self._factorized_bits(compressed_z, self.bitEstimator_z)
        total_bits_mv, _ = self._factorized_bits(quant_mv, self.bitEstimator_mv)

        num_pixels = batch_size * height * width
        bpp_feature = total_bits_feature / num_pixels
        bpp_z = total_bits_z / num_pixels
        bpp_mv = total_bits_mv / num_pixels
        bpp = bpp_feature + bpp_z + bpp_mv
        return clipped_recon_image, mse_loss, warploss, interloss, bpp_feature, bpp_z, bpp_mv, bpp