Skip to content

Commit 22da5db

Browse files
authoredDec 2, 2021
Add files via upload
1 parent 178b0ef commit 22da5db

File tree

16 files changed

+6074
-0
lines changed

16 files changed

+6074
-0
lines changed
 

‎gcn_layers.py

+396
Large diffs are not rendered by default.

‎image_inpaint.py

+249
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
import os
2+
import cv2
3+
import numpy as np
4+
# import random
5+
import imutils
6+
import copy
7+
# import tensorflow.compat.v1 as tf
8+
# np.random.seed(0)
9+
10+
# #########
11+
# def _bytes_feature(value):
12+
# return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
13+
# def gen_tf_ex(imgs,segs,ex_name,gt_valid_id,writer,size,img_num):
14+
# W,H = size
15+
16+
# concat_view = np.concatenate(imgs,axis=1)
17+
# if segs is not None:
18+
# concat_seg = np.concatenate(segs,axis=1)
19+
# else:
20+
# concat_seg = np.zeros((H,W*img_num,1))
21+
22+
# # gt_valid_id = 0
23+
# good_seg = np.zeros((H,W,1))
24+
25+
# h, w, c = concat_view.shape
26+
# if not (h==H and w==W*img_num and c==3):
27+
# print('STOP:',concat_view.shape,concat_depth.shape)
28+
# return writer
29+
30+
# # CONVERT THE VARIABLES TO THE TARGET TYPE: IMPORTANT!!!
31+
# concat_view = concat_view.astype(np.uint8)
32+
# good_seg = good_seg.astype(np.float32)
33+
# concat_seg = concat_seg.astype(np.float32)
34+
# gt_valid_id = np.array([gt_valid_id]).astype(np.int32)
35+
36+
# bbox_seq = np.zeros(img_num*4)
37+
# # print(concat_seg.shape,concat_view.shape)
38+
39+
# example = tf.train.Example(features=tf.train.Features(feature={
40+
# 'image_seq': _bytes_feature(concat_view.tostring()),
41+
# 'good_seg': _bytes_feature(good_seg.tostring()),
42+
# 'seg_seq': _bytes_feature(concat_seg.tostring()),
43+
# 'gt_valid_id': _bytes_feature(gt_valid_id.tostring()),
44+
# 'bbox_seq': _bytes_feature(bbox_seq.tostring()),
45+
# # 'bbox_segs': _bytes_feature(bbox_segs.tostring()),
46+
# # 'edge_seq': _bytes_feature(edge_seq.tostring()),
47+
# 'seq_name': _bytes_feature(str.encode(ex_name)),
48+
# }))
49+
50+
# writer.write(example.SerializeToString())
51+
# return writer
52+
# #########
53+
54+
def distance(p1, p2):
    """Return the Euclidean distance between two points or point batches.

    Both inputs are passed through ``np.squeeze`` before subtraction.

    Args:
        p1: array-like point, or batch of points.
        p2: array-like point(s) broadcastable against ``p1`` after squeezing.

    Returns:
        A 1-D array of per-row norms when the squeezed difference is 2-D,
        otherwise a single norm over the whole difference.
    """
    delta = np.squeeze(p1) - np.squeeze(p2)
    if delta.ndim == 2:
        # One distance per row (one row per point).
        return np.linalg.norm(delta, axis=1)
    return np.linalg.norm(delta)
63+
64+
def ext_endo_pos(img):
    """Locate a circle in a BGR image with the Hough gradient transform.

    The image is converted to grayscale and handed to ``cv2.HoughCircles``
    with the radius restricted to 100-150 pixels.

    Args:
        img: BGR image accepted by ``cv2.cvtColor``.

    Returns:
        ``(x, y, r)`` of the first circle reported by OpenCV, or
        ``(nan, nan, nan)`` when no circle is detected.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, 1, 10,
                               param1=50, param2=30,
                               minRadius=100, maxRadius=150)
    if circles is None:
        return (np.nan, np.nan, np.nan)

    # Only the first reported circle is used.
    x, y, r = circles[0][0]
    return (x, y, r)
74+
75+
def genMask(endo_pos, img):
    """Build a filled-circle mask with the same shape and dtype as ``img``.

    Args:
        endo_pos: ``(cx, cy, r)`` circle description, e.g. the return value
            of ``ext_endo_pos``. Any NaN component means "no circle".
        img: array whose shape and dtype the mask copies.

    Returns:
        An array of ones when ``endo_pos`` contains NaN (nothing is masked
        out); otherwise zeros with ones inside the circle.
    """
    cx, cy, r = endo_pos
    mask = np.zeros_like(img)
    # Check every component: int(nan) below would raise if only cx were tested.
    if np.isnan(np.asarray([cx, cy, r], dtype=float)).any():
        return 1 - mask
    cv2.circle(mask, (int(cx), int(cy)), int(r), (1, 1, 1), -1)
    return mask
82+
83+
def shift_image(img, dx, dy):
    """Translate ``img`` by ``(dx, dy)`` pixels, keeping the output size.

    Args:
        img: image array with at least two dimensions (rows, cols[, ...]).
        dx: shift along the column (x) axis.
        dy: shift along the row (y) axis.

    Returns:
        The result of ``cv2.warpAffine`` with a pure-translation matrix.
    """
    # shape[:2] lets single-channel (2-D) images through as well.
    rows, cols = img.shape[:2]
    translation = np.float32([[1, 0, dx], [0, 1, dy]])
    return cv2.warpAffine(img, translation, (cols, rows))
88+
89+
#####
90+
91+
def postproc_image(image, resize_wh=None, nearest_interpolate=True):
    """Optionally resize ``image`` to ``resize_wh``.

    Args:
        image: image array.
        resize_wh: target size passed to ``cv2.resize``; ``None`` leaves the
            image untouched.
        nearest_interpolate: use ``cv2.INTER_NEAREST`` when true, otherwise
            the ``cv2.resize`` default interpolation.

    Returns:
        The resized image, or the input object itself when ``resize_wh`` is
        ``None``.
    """
    if resize_wh is None:
        return image
    if nearest_interpolate:
        return cv2.resize(image, resize_wh, interpolation=cv2.INTER_NEAREST)
    return cv2.resize(image, resize_wh)
98+
99+
# def inpaint_image(imagefile, labelfile, inpaint_dir, seg_length, resize_wh=None, angle_range=(30,60),shift_range=(10,30,30,80),bg_max_shift=40):
100+
def inpaint_image(imagefile, labelfile, inpaint_dir, seg_length, resize_wh=None, angle_range=(30,40),shift_range=(10,30,20,60),bg_max_shift=40):
    """Synthesise a frame/label sequence around one annotated image.

    Files read from ``inpaint_dir`` (``<stem>`` is the basename of
    ``imagefile`` without its extension):

    * ``<stem>_bg.jpg``    - left half is the background image, right half
      is multiplied into every frame as an edge mask.
    * ``<stem>_label.png`` - instance label image (optional).
    * ``<stem>_inst.jpg``  - instrument pixels pasted onto the background.

    The background is translated linearly over the sequence. When a
    ``_label.png`` exists, every non-zero label id is rotated/translated
    independently and blended onto each frame. The real image and label are
    always placed at index ``seg_length // 2``.

    Side effect: in the instrument path a normalised copy of the label is
    written to ``./samples/<stem>.png``.

    Args:
        imagefile: path of the real frame.
        labelfile: path of the real frame's label image.
        inpaint_dir: directory holding the three auxiliary images.
        seg_length: number of frames to generate.
        resize_wh: optional size forwarded to ``postproc_image``.
        angle_range: pair of maximum rotation spans in degrees.
        shift_range: ``(min_a, max_a, min_b, max_b)`` translation bounds.
        bg_max_shift: upper bound of the background translation.
            For files whose basename contains ``'EP'`` the last three
            arguments are overwritten with fixed values.

    Returns:
        ``(frames, labels, gt_valid_id)`` with ``gt_valid_id ==
        seg_length // 2``, or ``(None, None, None)`` when no ``_label.png``
        exists although the label image is non-zero.
    """
    is_ep = 'EP' in os.path.basename(imagefile)
    if is_ep:
        angle_range = (30, 40)
        shift_range = (10, 30, 20, 60)
        bg_max_shift = 40

    sample_num = seg_length // 2

    image = cv2.imread(imagefile)

    # splitext instead of [:-4] so extensions of any length are stripped.
    basename = os.path.splitext(os.path.basename(imagefile))[0]
    inpaint_img = cv2.imread(os.path.join(inpaint_dir, basename + "_bg.jpg"))
    height, width, _ = inpaint_img.shape
    width = int(width / 2)
    inpaint_img, edge_mask = inpaint_img[:, :width], inpaint_img[:, width:]

    # Background: linear translation from (-dx, -dy) to (dx, dy).
    max_shift = np.random.random() * bg_max_shift
    dx, dy = np.random.random(2) * max_shift - max_shift / 2
    dx_list = np.linspace(-dx, dx, seg_length)
    dy_list = np.linspace(-dy, dy, seg_length)
    _bg_seq = [shift_image(inpaint_img, bg_dx, bg_dy) * edge_mask
               for bg_dx, bg_dy in zip(dx_list, dy_list)]

    new_label = cv2.imread(os.path.join(inpaint_dir, basename + "_label.png"))
    inpaint_inst = cv2.imread(os.path.join(inpaint_dir, basename + "_inst.jpg"))

    if new_label is None:
        # No instrument to paste: only usable when the real label is empty.
        label = cv2.imread(labelfile)
        if np.sum(label) > 0:
            return None, None, None

        _bg_seq[sample_num] = image * edge_mask

        _label_seq = [np.zeros((height, width, 1))] * seg_length
        _new_label = label * edge_mask
        _label_seq[sample_num] = _new_label[:, :, 0][..., np.newaxis]

        _bg_seq = [postproc_image(frame, resize_wh=resize_wh) for frame in _bg_seq]
        _label_seq = [postproc_image(lab, resize_wh=resize_wh) for lab in _label_seq]
        gt_valid_id = sample_num
        return _bg_seq, _label_seq, gt_valid_id

    cv2.imwrite(os.path.join("./samples", basename + ".png"), new_label / np.max(new_label) * 255)

    if is_ep:
        new_width = width * 1.2
        new_height = height * 1.2
    else:
        new_width = width * 2
        new_height = height * 2
    # Offsets of the centred (height, width) window in the enlarged canvas;
    # loop-invariant, so computed once.
    label_dx, label_dy = int((new_width - width) / 2), int((new_height - height) / 2)

    # The first branch can never be taken (random() >= 0.0); the draw is
    # kept so the random stream is unchanged.
    if np.random.random() < 0.0:
        _angle_range = angle_range[0]
        _shift_min, _shift_max = shift_range[0], shift_range[1]
    else:
        _angle_range = angle_range[1]
        _shift_min, _shift_max = shift_range[2], shift_range[3]

    new_label_all = copy.deepcopy(new_label)
    inst_ids = np.unique(new_label_all[new_label_all > 0])
    _img_seq = copy.deepcopy(_bg_seq)
    _label_seq = [np.zeros((height, width))] * seg_length
    center_label = None  # real label image, read lazily and at most once

    for inst_id in inst_ids:
        new_label = np.zeros_like(new_label_all)
        new_label[new_label_all == inst_id] = 1

        # Rotation schedule: ang_bw -> 0 before the centre frame, 0 -> ang_fw after.
        max_degree = np.random.random() * _angle_range - _angle_range / 2
        ang_bw = np.random.random() * max_degree - max_degree
        ang_fw = np.random.random() * max_degree
        ang_list = list(np.linspace(ang_bw, 0, sample_num + 1)[:-1]) + list(np.linspace(0, ang_fw, sample_num + 1))
        # Translation schedule for this instrument.
        max_shift = np.random.random() * (_shift_max - _shift_min) + _shift_min
        dx, dy = np.random.random(2) * max_shift - max_shift / 2
        dx_list = np.linspace(-dx, dx, seg_length)
        dy_list = np.linspace(-dy, dy, seg_length)

        for _frame_id, (angle, fdx, fdy, bg_img) in enumerate(zip(ang_list, dx_list, dy_list, _img_seq)):
            _inpaint_inst = imutils.rotate(inpaint_inst, angle)
            _inpaint_inst = shift_image(_inpaint_inst, fdx, fdy)

            _new_label = imutils.rotate(new_label, angle)
            _new_label = shift_image(_new_label, fdx, fdy)
            _new_label = _new_label[label_dy:label_dy + height, label_dx:label_dx + width]
            _new_label = _new_label * edge_mask

            _inpaint_inst = _inpaint_inst[label_dy:label_dy + height, label_dx:label_dx + width]
            _smooth_new_label = cv2.GaussianBlur(_new_label.astype(np.float32), (3, 3), 3)
            _inpaint_img = bg_img * (1 - _smooth_new_label) + _inpaint_inst * _smooth_new_label

            if _frame_id == sample_num:
                # The centre frame is always the real image and label.
                if center_label is None:
                    center_label = cv2.imread(labelfile)
                _inpaint_img = image * edge_mask
                _new_label = center_label * edge_mask

            if np.random.random() < 0.6:
                # Random brightness offset.
                _inpaint_img = np.clip(_inpaint_img + np.random.randint(10, 60) * 1.0, 0., 255.)
            # NOTE(review): nesting rebuilt from de-indented source; confirm
            # this re-masking was not meant to be inside the branch above.
            _inpaint_img = _inpaint_img * edge_mask
            _img_seq[_frame_id] = _inpaint_img

            # Accumulate a binary union of all instrument labels per frame.
            _new_label = _new_label[:, :, 0]
            _new_label = _new_label + _label_seq[_frame_id]
            _new_label[_new_label > 0] = 1
            _label_seq[_frame_id] = _new_label

    _img_seq = [postproc_image(frame, resize_wh=resize_wh) for frame in _img_seq]
    _label_seq = [postproc_image(lab[..., np.newaxis], resize_wh=resize_wh) for lab in _label_seq]
    gt_valid_id = sample_num
    return _img_seq, _label_seq, gt_valid_id

‎labels.json

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
[{
2+
"name": "background-tissue",
3+
"color": [
4+
0,
5+
0,
6+
0
7+
],
8+
"classid": 0
9+
},
10+
{
11+
"name": "instrument",
12+
"color": [
13+
0,
14+
255,
15+
0
16+
],
17+
"classid": 1
18+
}
19+
]

‎mobilenetv1.py

+445
Large diffs are not rendered by default.

‎model.py

+2,107
Large diffs are not rendered by default.

‎ops.py

+170
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import math
2+
import numpy as np
3+
import tensorflow as tf
4+
5+
from tensorflow.python.framework import ops
6+
7+
from utils import *
8+
9+
10+
# Short aliases for the TF1 summary API.
image_summary = tf.summary.image
scalar_summary = tf.summary.scalar
histogram_summary = tf.summary.histogram
merge_summary = tf.summary.merge
SummaryWriter = tf.summary.FileWriter
# Seed passed to the weight initializers and dropout in this module.
seed = 23
16+
17+
def batchnorm(input_, is_train=False, name="batchnorm"):
    """Apply ``tf.layers.batch_normalization`` inside variable scope ``name``.

    Args:
        input_: tensor to normalise.
        is_train: forwarded as the ``training`` argument.
        name: variable scope name.

    Returns:
        The normalised tensor.
    """
    with tf.variable_scope(name):
        normalized = tf.layers.batch_normalization(input_, training=is_train)
    return normalized
21+
22+
def conv2d(input_, output_dim, ksize=3, stride=2, stddev=0.02, name="conv2d"):
    """2-D convolution with bias and SAME padding.

    Args:
        input_: input tensor; its last dimension is taken as input channels.
        output_dim: number of output channels.
        ksize: square kernel size.
        stride: stride used for both spatial dimensions.
        stddev: std-dev of the truncated-normal weight initializer.
        name: variable scope holding ``w`` and ``biases``.

    Returns:
        The convolution output with the bias added (no activation).
    """
    with tf.variable_scope(name):
        in_channels = input_.get_shape()[-1]
        w = tf.get_variable(
            'w', [ksize, ksize, in_channels, output_dim],
            initializer=tf.truncated_normal_initializer(stddev=stddev, seed=seed))
        conv = tf.nn.conv2d(input_, w, strides=[1, stride, stride, 1], padding='SAME')

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        return tf.nn.bias_add(conv, biases)
34+
35+
def conv2d_dilated(input_, output_dim, ksize=3, rate=2, stddev=0.02, name="conv2d_dilated"):
    """Atrous (dilated) 2-D convolution with bias and SAME padding.

    Args:
        input_: input tensor; its last dimension is taken as input channels.
        output_dim: number of output channels.
        ksize: square kernel size.
        rate: dilation rate passed to ``tf.nn.atrous_conv2d``.
        stddev: std-dev of the truncated-normal weight initializer.
        name: variable scope holding ``w`` and ``biases``.

    Returns:
        The convolution output with the bias added (no activation).
    """
    with tf.variable_scope(name):
        in_channels = input_.get_shape()[-1]
        w = tf.get_variable(
            'w', [ksize, ksize, in_channels, output_dim],
            initializer=tf.truncated_normal_initializer(stddev=stddev, seed=seed))
        conv = tf.nn.atrous_conv2d(input_, w, rate=rate, padding="SAME")

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        return tf.nn.bias_add(conv, biases)
48+
49+
def deconv2d(input_, output_shape,
             ksize=5, stride=2, stddev=0.02,
             name="deconv2d", with_w=False):
    """Transposed 2-D convolution with bias.

    Args:
        input_: input tensor; its last dimension is taken as input channels.
        output_shape: full output shape; its last entry is the number of
            output channels.
        ksize: square kernel size.
        stride: stride used for both spatial dimensions.
        stddev: std-dev of the truncated-normal weight initializer.
        name: variable scope holding ``w`` and ``biases``.
        with_w: also return the filter and bias variables.

    Returns:
        The output tensor, or ``(output, w, biases)`` when ``with_w`` is true.
    """
    with tf.variable_scope(name):
        # filter : [height, width, output_channels, in_channels]
        w = tf.get_variable(
            'w', [ksize, ksize, output_shape[-1], input_.get_shape()[-1]],
            initializer=tf.truncated_normal_initializer(stddev=stddev, seed=seed))

        try:
            deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                                            strides=[1, stride, stride, 1])
        # Support for versions of TensorFlow before 0.7.0
        except AttributeError:
            deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
                                    strides=[1, stride, stride, 1])

        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        # No reshape to deconv.get_shape(): bias_add keeps the shape, and the
        # reshape fails when the static shape is only partially known.
        # conv2d / conv2d_dilated in this module do the same.
        deconv = tf.nn.bias_add(deconv, biases)

        if with_w:
            return deconv, w, biases
        return deconv
73+
74+
def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU: ``max(x, leak * x)``.

    ``name`` is accepted for signature compatibility but is not used.
    """
    return tf.maximum(x, leak * x)
76+
77+
def prelu(x, name="prelu"):
    """Parametric ReLU with one learnable slope per last-axis entry of ``x``.

    The slope variable (named ``"prelu"`` inside scope ``name``) is
    initialised to 0.
    """
    with tf.variable_scope(name):
        alpha = tf.get_variable("prelu", shape=x.get_shape()[-1], initializer=tf.constant_initializer(0.0))
        positive = tf.maximum(0.0, x)
        negative = alpha * tf.minimum(0.0, x)
        return positive + negative
81+
82+
83+
def relu(x, name="relu"):
    """ReLU: ``max(x, 0)``.

    ``name`` is accepted for signature compatibility but is not used.
    """
    return tf.maximum(x, 0)
85+
86+
def separable_conv2d(input_, output_dim, ksize=3, stride=1, rate=1, stddev=0.02, name=''):
    """Depthwise-separable 2-D convolution with bias and SAME padding.

    Args:
        input_: 4-D input tensor; dimension 3 is read as the channel count.
        output_dim: number of output channels of the pointwise filter.
        ksize: square size of the depthwise kernel.
        stride: stride used for both spatial dimensions.
        rate: dilation rate used for both spatial dimensions.
        stddev: std-dev of the truncated-normal weight initializers.
        name: prefix of the variable scope ``<name>_separable_conv2d``.

    Returns:
        The convolution output with the bias added (no activation).
    """
    with tf.variable_scope(name + "_separable_conv2d"):
        in_chns = input_.get_shape()[3].value
        # Channel multiplier of the depthwise filter is 1.
        w_depth = tf.get_variable(
            'w_depth', [ksize, ksize, in_chns, 1],
            initializer=tf.truncated_normal_initializer(stddev=stddev, seed=seed))
        w_point = tf.get_variable(
            'w_point', [1, 1, in_chns, output_dim],
            initializer=tf.truncated_normal_initializer(stddev=stddev, seed=seed))
        conv = tf.nn.separable_conv2d(input_,
                                      depthwise_filter=w_depth,
                                      pointwise_filter=w_point,
                                      strides=[1, stride, stride, 1],
                                      padding="SAME",
                                      rate=[rate, rate],
                                      name="sep_conv")
        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        output = tf.nn.bias_add(conv, biases)

    return output
102+
103+
# do_batchnorm=True
104+
def atrous_spatial_pyramid_pooling(input_, output_stride=16, depth=256, is_train=False, dropout=False, keep_prob=1.0):
    """Atrous Spatial Pyramid Pooling.

    Four parallel branches over ``input_`` are concatenated on the channel
    axis and fused by a 1x1 convolution:

    * a 1x1 convolution,
    * two 3x3 atrous convolutions with rates (2, 4), doubled to (4, 8)
      when ``output_stride == 8``,
    * global average pooling followed by a 1x1 convolution, resized
      bilinearly back to the spatial size of ``input_``.

    Each convolution is followed by batch normalisation and ReLU.

    Args:
        input_: A tensor of size [batch, height, width, channels].
        output_stride: only the value 8 changes behaviour (doubles the rates).
        depth: number of output channels of every convolution.
        is_train: forwarded to every batch-normalisation layer.
        dropout: apply dropout to the fused output.
        keep_prob: keep probability used when ``dropout`` is true.

    Returns:
        The atrous spatial pyramid pooling output.
    """
    with tf.variable_scope("aspp"):
        atrous_rates = [2, 4]
        if output_stride == 8:
            atrous_rates = [2 * rate for rate in atrous_rates]

        # (a) one 1x1 convolution and two 3x3 atrous convolutions.
        h1 = conv2d(input_, depth, ksize=1, stride=1, name="conv1")
        h1 = tf.nn.relu(batchnorm(h1, is_train, 'bn1'))

        h2 = conv2d_dilated(input_, depth, ksize=3, rate=atrous_rates[0], name="conv3_1")
        h2 = tf.nn.relu(batchnorm(h2, is_train, 'bn2'))

        h3 = conv2d_dilated(input_, depth, ksize=3, rate=atrous_rates[1], name="conv3_2")
        h3 = tf.nn.relu(batchnorm(h3, is_train, 'bn3'))

        # (b) the image-level features
        input_size = tf.shape(input_)[1:3]
        h0 = tf.reduce_mean(input_, [1, 2], name='global_average_pooling', keepdims=True)
        h0 = conv2d(h0, depth, ksize=1, stride=1, name="conv1_pool")
        h0 = tf.nn.relu(batchnorm(h0, is_train, 'bn_gap'))
        h0 = tf.image.resize_bilinear(h0, input_size, name='upsample')

        h = tf.concat([h0, h1, h2, h3], axis=3)

        h = conv2d(h, depth, ksize=1, stride=1, name="conv1_out")
        h = tf.nn.relu(batchnorm(h, is_train, 'bn_out'))

        if dropout:
            h = tf.nn.dropout(h, keep_prob, seed=seed)

        return h
169+
170+

‎resnetv2.py

+513
Large diffs are not rendered by default.

‎resources/env.yml

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# run: conda env create -f resources/env.yml
2+
# to update existing environment: conda env update -f resources/env.yml
3+
name: mffa
4+
channels:
5+
- conda-forge
6+
dependencies:
7+
- tensorflow-gpu=1.13
8+
- cudatoolkit
9+
- opencv
10+
- imutils

‎rnn_cell.py

+135
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import tensorflow as tf
2+
3+
class ConvLSTMCell(tf.nn.rnn_cell.RNNCell):
    """A LSTM cell with convolutions instead of multiplications.

    Reference:
      Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine learning approach for precipitation nowcasting." Advances in Neural Information Processing Systems. 2015.
    """

    def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, normalize=True, peephole=True, data_format='channels_last', reuse=None):
        """Store the cell configuration.

        Args:
            shape: list with the spatial size of the state (no batch/channels).
            filters: number of state channels.
            kernel: list with the spatial kernel size.
            forget_bias: constant added to the forget gate before the sigmoid.
            activation: non-linearity for the candidate and the output.
            normalize: use layer normalisation instead of a bias term.
            peephole: add peephole connections from the cell state.
            data_format: ``'channels_last'`` or ``'channels_first'``.
            reuse: forwarded to the ``RNNCell`` base class.

        Raises:
            ValueError: for any other ``data_format``.
        """
        super(ConvLSTMCell, self).__init__(_reuse=reuse)
        self._kernel = kernel
        self._filters = filters
        self._forget_bias = forget_bias
        self._activation = activation
        self._normalize = normalize
        self._peephole = peephole
        if data_format == 'channels_last':
            self._size = tf.TensorShape(shape + [self._filters])
            # Index of the last axis once the batch axis is prepended.
            self._feature_axis = self._size.ndims
            self._data_format = None
        elif data_format == 'channels_first':
            self._size = tf.TensorShape([self._filters] + shape)
            # Tensors seen by call() carry a leading batch axis, so the
            # channel axis is 1; axis 0 would be the batch axis.
            self._feature_axis = 1
            self._data_format = 'NC'
        else:
            raise ValueError('Unknown data_format')

    @property
    def state_size(self):
        """Sizes of the (cell, hidden) state pair, without the batch axis."""
        return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size)

    @property
    def output_size(self):
        """Size of the output, without the batch axis."""
        return self._size

    def call(self, x, state):
        """Run one step; returns ``(h, LSTMStateTuple(c, h))``."""
        c, h = state

        x = tf.concat([x, h], axis=self._feature_axis)
        # Read the channel count from the feature axis (the last axis is
        # only correct for channels_last).
        n = x.shape[self._feature_axis].value
        m = 4 * self._filters if self._filters > 1 else 4
        W = tf.get_variable('kernel', self._kernel + [n, m])
        y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format)
        if not self._normalize:
            y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
        # One convolution produces all four gates; split along channels.
        j, i, f, o = tf.split(y, 4, axis=self._feature_axis)

        if self._peephole:
            i += tf.get_variable('W_ci', c.shape[1:]) * c
            f += tf.get_variable('W_cf', c.shape[1:]) * c

        if self._normalize:
            j = tf.contrib.layers.layer_norm(j)
            i = tf.contrib.layers.layer_norm(i)
            f = tf.contrib.layers.layer_norm(f)

        f = tf.sigmoid(f + self._forget_bias)
        i = tf.sigmoid(i)
        c = c * f + i * self._activation(j)

        if self._peephole:
            # The output-gate peephole uses the updated cell state.
            o += tf.get_variable('W_co', c.shape[1:]) * c

        if self._normalize:
            o = tf.contrib.layers.layer_norm(o)
            c = tf.contrib.layers.layer_norm(c)

        o = tf.sigmoid(o)
        h = o * self._activation(c)

        state = tf.nn.rnn_cell.LSTMStateTuple(c, h)

        return h, state
75+
76+
77+
class ConvGRUCell(tf.nn.rnn_cell.RNNCell):
    """A GRU cell with convolutions instead of multiplications."""

    def __init__(self, shape, filters, kernel, activation=tf.tanh, normalize=True, data_format='channels_last', reuse=None):
        """Store the cell configuration.

        Args:
            shape: list with the spatial size of the state (no batch/channels).
            filters: number of state channels.
            kernel: list with the spatial kernel size.
            activation: non-linearity applied to the candidate state.
            normalize: use layer normalisation instead of bias terms.
            data_format: ``'channels_last'`` or ``'channels_first'``.
            reuse: forwarded to the ``RNNCell`` base class.

        Raises:
            ValueError: for any other ``data_format``.
        """
        super(ConvGRUCell, self).__init__(_reuse=reuse)
        self._filters = filters
        self._kernel = kernel
        self._activation = activation
        self._normalize = normalize
        if data_format == 'channels_last':
            self._size = tf.TensorShape(shape + [self._filters])
            # Index of the last axis once the batch axis is prepended.
            self._feature_axis = self._size.ndims
            self._data_format = None
        elif data_format == 'channels_first':
            self._size = tf.TensorShape([self._filters] + shape)
            # Tensors seen by call() carry a leading batch axis, so the
            # channel axis is 1; axis 0 would be the batch axis.
            self._feature_axis = 1
            self._data_format = 'NC'
        else:
            raise ValueError('Unknown data_format')

    @property
    def state_size(self):
        """Size of the hidden state, without the batch axis."""
        return self._size

    @property
    def output_size(self):
        """Size of the output, without the batch axis."""
        return self._size

    def call(self, x, h):
        """Run one step; returns the new hidden state twice ``(h, h)``."""
        channels = x.shape[self._feature_axis].value

        with tf.variable_scope('gates'):
            inputs = tf.concat([x, h], axis=self._feature_axis)
            n = channels + self._filters
            m = 2 * self._filters if self._filters > 1 else 2
            W = tf.get_variable('kernel', self._kernel + [n, m])
            y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
            if self._normalize:
                r, u = tf.split(y, 2, axis=self._feature_axis)
                r = tf.contrib.layers.layer_norm(r)
                u = tf.contrib.layers.layer_norm(u)
            else:
                # Gate bias starts at one.
                y += tf.get_variable('bias', [m], initializer=tf.ones_initializer())
                r, u = tf.split(y, 2, axis=self._feature_axis)
            r, u = tf.sigmoid(r), tf.sigmoid(u)

        with tf.variable_scope('candidate'):
            inputs = tf.concat([x, r * h], axis=self._feature_axis)
            n = channels + self._filters
            m = self._filters
            W = tf.get_variable('kernel', self._kernel + [n, m])
            y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
            if self._normalize:
                y = tf.contrib.layers.layer_norm(y)
            else:
                y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
            h = u * h + (1 - u) * self._activation(y)

        return h, h

‎test.py

+384
Large diffs are not rendered by default.

‎test.sh

+625
Large diffs are not rendered by default.

‎test_tf.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from tensorflow.python.client import device_lib
2+
print(device_lib.list_local_devices())

‎train.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import os
2+
import numpy as np
3+
import cv2
4+
from model import *
5+
from utils import *
6+
import tensorflow as tf
7+
8+
# Single seed shared by NumPy and TensorFlow.
overall_random_seed = 23 # EP:5, sinus:23
np.random.seed(overall_random_seed)
tf.set_random_seed(overall_random_seed)
11+
12+
def del_all_flags(FLAGS):
    """Delete every flag currently registered on ``FLAGS``.

    Args:
        FLAGS: object exposing ``_flags()`` (a mapping keyed by flag name)
            and supporting attribute deletion by flag name.
    """
    # Snapshot the names first: the mapping may change while flags are deleted.
    for flag_name in list(FLAGS._flags()):
        FLAGS.__delattr__(flag_name)
17+
# Drop previously registered flags and start from a fresh, seeded graph.
del_all_flags(tf.flags.FLAGS)
tf.reset_default_graph()
tf.set_random_seed(overall_random_seed)

flags = tf.app.flags
flags.DEFINE_integer("epoch",30, "Epoch to train [25]")
flags.DEFINE_integer("batch_size", 16, "The size of batch images [64]")
flags.DEFINE_integer("seed", overall_random_seed, "random seed")
flags.DEFINE_integer("input_height", 240, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", 240, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("crop_height", 192, "The size of image to crop")
flags.DEFINE_integer("crop_width", 192, "")
flags.DEFINE_integer("temporal_len",4,"the number of consecutive frames to input")

# Alternative train_dataset: "../sinus_data/cadaver"
flags.DEFINE_string("train_dataset", "../sinus_data/syn_cadaver", "train dataset direction")
flags.DEFINE_string("frame_dataset", "../sinus_data/cadaver/frame_dataset", "frame dataset direction")
flags.DEFINE_string("video_dir", "../sinus_data/cadaver/videos", "train dataset direction")
flags.DEFINE_string("datasets", "cf1cf2", "")

flags.DEFINE_string("img_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("label_pattern", "*.png", "Glob pattern of filename of input labels [*]")

flags.DEFINE_string("checkpoint_dir", "./checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("save_checkpoint_dir", "", "Directory name to save the checkpoints [checkpoint]")
# Alternative pretrain_dir: "../pretrain/resnet_v2_50_2017_04_14"
flags.DEFINE_string("pretrain_dir", "../pretrain/mobilenet_v1_1.0_224", "")

# model_type options: unet, deeplab_mobilenet, deeplab_resnet
flags.DEFINE_string("model_type", "deeplab_mobilenet", "")

# Integer switches; main() treats 0 as False and anything else as True.
flags.DEFINE_integer("continue_train",0,"")
flags.DEFINE_integer("pass_hidden",0,"")
flags.DEFINE_integer("seq_label",0,"")
flags.DEFINE_integer("teacher_mode",0,"")
flags.DEFINE_integer("disable_gcn",0,"")

flags.DEFINE_integer("rnn_mode",1, "")
flags.DEFINE_integer("decay_epoch",15, "Epoch to decay learning rate")
flags.DEFINE_float("learning_rate",0.000125,"")

flags.DEFINE_string("gpu", '0', "gpu")
FLAGS = flags.FLAGS

# Restrict the visible CUDA devices to the value of --gpu.
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
66+
def main(_):
    """Build a ``DeepLab`` network from the command-line flags and train it.

    Prints the flags, creates the checkpoint directories, opens a session
    with GPU memory growth enabled and calls ``net.train(FLAGS)``.
    """
    pp.pprint(flags.FLAGS.__flags)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
    if FLAGS.save_checkpoint_dir != "":
        os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)

    # cvt number to bool
    continue_train = FLAGS.continue_train != 0
    pass_hidden = FLAGS.pass_hidden != 0
    seq_label = FLAGS.seq_label != 0
    teacher_mode = FLAGS.teacher_mode != 0
    disable_gcn = FLAGS.disable_gcn != 0

    color_table = load_color_table('./labels.json')

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True
    tf.reset_default_graph()
    tf.set_random_seed(overall_random_seed)
    with tf.Session(config=run_config) as sess:

        net = DeepLab(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            crop_width=FLAGS.crop_width,
            crop_height=FLAGS.crop_height,
            batch_size=FLAGS.batch_size,
            seed=FLAGS.seed,
            temporal_len=FLAGS.temporal_len,
            img_pattern=FLAGS.img_pattern,
            label_pattern=FLAGS.label_pattern,
            checkpoint_dir=FLAGS.checkpoint_dir,
            save_checkpoint_dir=FLAGS.save_checkpoint_dir,
            pretrain_dir=FLAGS.pretrain_dir,
            datasets=FLAGS.datasets,
            train_dataset=FLAGS.train_dataset,
            frame_dataset=FLAGS.frame_dataset,
            video_dir=FLAGS.video_dir,
            continue_train=continue_train,
            pass_hidden=pass_hidden,
            seq_label=seq_label,
            teacher_mode=teacher_mode,
            disable_gcn=disable_gcn,
            model_type=FLAGS.model_type,
            rnn_mode=FLAGS.rnn_mode,
            learning_rate=FLAGS.learning_rate,
            num_class=2,
            color_table=color_table,
            test_video=False, is_train=True)

        net.train(FLAGS)


if __name__ == '__main__':
    tf.app.run()

‎train.sh

+581
Large diffs are not rendered by default.

‎utils.py

+310
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
"""
2+
Some codes from https://github.com/Newmu/dcgan_code
3+
"""
4+
from __future__ import division
5+
import math
6+
import json
7+
import random
8+
import pprint
9+
import scipy.misc
10+
import numpy as np
11+
from time import gmtime, strftime
12+
from six.moves import xrange
13+
import cv2
14+
import os
15+
from image_inpaint import *
16+
17+
import tensorflow as tf
18+
import tensorflow.contrib.slim as slim
19+
pp = pprint.PrettyPrinter()


def get_stddev(x, k_h, k_w):
    """Return ``1 / sqrt(k_w * k_h * c)`` where ``c`` is ``x.get_shape()[-1]``."""
    return 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
22+
23+
24+
def load_color_table(json_file):
    """Load per-class colours from a JSON label file.

    The file must hold a list of objects, each with a three-element
    ``"color"`` list.

    Args:
        json_file: path of the UTF-8 encoded JSON file.

    Returns:
        ``[R, G, B]``: three lists holding element 0, 1 and 2 of every
        entry's ``"color"``, in file order.
    """
    # Context manager so the handle is closed even if parsing fails.
    with open(json_file, "r", encoding='utf-8') as f:
        colors = json.load(f)

    R = [c['color'][0] for c in colors]
    G = [c['color'][1] for c in colors]
    B = [c['color'][2] for c in colors]
    return [R, G, B]
35+
36+
def idxmap2colormap(im_idx, color_table):
    """Convert a class-index map into a 3-channel ``uint8`` colour image.

    Args:
        im_idx: array of class indices.
        color_table: ``[R, G, B]`` lists as returned by ``load_color_table``.

    Returns:
        ``np.dstack`` of the R, G and B planes. Indices outside
        ``range(len(R))`` are left at 0 in every plane.
    """
    R, G, B = color_table
    class_num = len(R)
    planes = []
    for channel_values in (R, G, B):
        plane = np.zeros_like(im_idx, np.uint8)
        for class_id in range(class_num):
            plane[im_idx == class_id] = channel_values[class_id]
        planes.append(plane)
    return np.dstack(planes)
48+
49+
def show_all_variables():
    """Print a slim analysis of all trainable variables."""
    trainable = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(trainable, print_info=True)
52+
53+
def save_images(images, size, image_path):
    """Thin wrapper around ``imsave``; note the different argument order."""
    return imsave(images, size, image_path)
55+
56+
def imread(path, resize_wh=None, nearest_interpolate=False, grayscale=False):
    """Read an image with OpenCV, optionally keep one channel and resize.

    Args:
        path: image file path.
        resize_wh: target size passed to ``cv2.resize``; ``None`` skips
            resizing.
        nearest_interpolate: use ``cv2.INTER_NEAREST`` when resizing.
        grayscale: keep only channel 0 of the loaded image.

    Returns:
        The loaded (and possibly resized) image array.
    """
    # NOTE(review): cv2.imread returns None for an unreadable file, which
    # is not handled here - confirm callers only pass existing paths.
    image = cv2.imread(path)
    if grayscale and image.shape[2] > 0:
        image = image[:, :, 0]
    if resize_wh is None:
        return image
    if nearest_interpolate:
        return cv2.resize(image, resize_wh, interpolation=cv2.INTER_NEAREST)
    return cv2.resize(image, resize_wh)
67+
68+
# read from folder
69+
def sequence_read(path_train, dir_frame, temporal_len, interval=2, resize_wh=None, nearest_interpolate=False, grayscale = False):
    """Read ``temporal_len`` frames around a frame named ``<video>_<index>``.

    The first frame is ``path_train`` (or the same file name inside
    ``dir_frame`` when ``path_train`` does not exist). Further frames come
    from ``dir_frame`` at indices ``index - interval * t``; if that backward
    window is incomplete, the forward window ``index + interval * t`` is
    tried instead.

    Args:
        path_train: path of the reference frame.
        dir_frame: directory holding the neighbouring frames.
        temporal_len: number of frames to return.
        interval: index step between consecutive frames.
        resize_wh, nearest_interpolate, grayscale: forwarded to ``imread``.

    Returns:
        A list of ``temporal_len`` images starting with the reference
        frame, or ``None`` when neither window is complete.
    """
    file = os.path.basename(path_train)
    stem, ext = os.path.splitext(file)
    # Split on the LAST underscore so video names may contain underscores.
    vname, idx = stem.rsplit('_', 1)

    if not os.path.exists(path_train):
        path_train = os.path.join(dir_frame, vname + '_' + idx + ext)
    first = imread(path_train, resize_wh, nearest_interpolate, grayscale)

    def _collect(step):
        """Return [first, neighbours...], stopping at the first missing file."""
        frames = [first]
        for t in range(1, temporal_len):
            idxt = str(int(idx) - step * t)
            patht = os.path.join(dir_frame, vname + '_' + idxt + ext)
            if not os.path.exists(patht):
                # A single gap already makes this window unusable.
                break
            frames.append(imread(patht, resize_wh, nearest_interpolate, grayscale))
        return frames

    # Backward window first, then the forward one.
    for step in (interval, -interval):
        frames = _collect(step)
        if len(frames) == temporal_len:
            return frames
    return None
106+
107+
def full_sequence_read(imgfile, labelfile, temporal_len, resize_wh=None, nearest_interpolate=False, grayscale=False):
    """Build a frame/label sequence for one sample via ``inpaint_image``.

    The inpainting directory is chosen from the image file name: names
    containing ``"EP"`` use the MICCAI directory, everything else uses the
    cadaver directory.

    Args:
        imgfile: path of the image; only its basename selects the directory.
        labelfile: path of the matching label file.
        temporal_len: sequence length forwarded to ``inpaint_image``.
        resize_wh: forwarded to ``inpaint_image`` as ``resize_wh``.
        nearest_interpolate: accepted for signature compatibility; unused.
        grayscale: accepted for signature compatibility; unused.

    Returns:
        ``(frames, labels)`` -- the first two of the three values returned by
        ``inpaint_image``; the third one is discarded.
    """
    is_ep_sample = "EP" in os.path.basename(imgfile)
    inpaint_dir = (
        "./get_miccai_dataset/inpaint_images"
        if is_ep_sample
        else "../sinus_data/cadaver/inpaint_images"
    )
    frames, labels, _unused_valid_id = inpaint_image(
        imgfile, labelfile, inpaint_dir, temporal_len, resize_wh=resize_wh)
    return frames, labels
114+
115+
# def full_sequence_read(imgfile, labelfile, temporal_len, resize_wh=None):
116+
# syn_path = "./syn_images"
117+
# _imgfile = os.path.join(syn_path,os.path.basename(imgfile))
118+
# _labelfile = os.path.join(syn_path,os.path.basename(labelfile))
119+
# # print(cv2.imread(_imgfile).shape,cv2.imread(_labelfile,0).shape)
120+
# frames = np.reshape(cv2.imread(_imgfile),(temporal_len,resize_wh[1],resize_wh[0],3))
121+
# labels = np.reshape(cv2.imread(_labelfile,0),(temporal_len,resize_wh[1],resize_wh[0]))
122+
# # cv2.imwrite(os.path.join("./samples",os.path.basename(imgfile)),np.concatenate(frames,axis=0))
123+
# # cv2.imwrite(os.path.join("./samples",os.path.basename(labelfile)),np.concatenate(labels,axis=0)*255)
124+
# return frames, labels
125+
126+
def imsave(images, size, path):
    """Combine ``images`` with ``merge`` and write the result to ``path``.

    Args:
        images: batch of images handed to ``merge``.
        size: layout argument handed to ``merge``.
        path: destination file path.

    Returns:
        The return value of ``scipy.misc.imsave``.
    """
    # NOTE(review): scipy.misc.imsave is no longer available in recent SciPy
    # releases -- confirm the pinned SciPy version before relying on this.
    merged = merge(images, size)
    canvas = np.squeeze(merged)  # drop singleton axes before writing
    return scipy.misc.imsave(path, canvas)
129+
130+
def center_crop(x, crop_h, crop_w=None,
                resize_h=64, resize_w=64):
    """Crop the centre of ``x`` and resize the crop.

    Args:
        x: image array; its first two axes are treated as height and width.
        crop_h: height of the central crop.
        crop_w: width of the central crop. Defaults to ``crop_h`` (square
            crop) when omitted or ``None``.
        resize_h: height of the returned image.
        resize_w: width of the returned image.

    Returns:
        The result of ``scipy.misc.imresize`` applied to the central crop.
    """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    # Top-left corner of the crop, rounded to the nearest pixel.
    top = int(round((h - crop_h) / 2.))
    left = int(round((w - crop_w) / 2.))
    # NOTE(review): scipy.misc.imresize is no longer available in recent
    # SciPy releases -- confirm the pinned SciPy version.
    return scipy.misc.imresize(
        x[top:top + crop_h, left:left + crop_w], [resize_h, resize_w])
139+
140+
def _overlap_scores(gt, pred):
    """Return ``(dice, iou)`` computed from the non-zero pixels of two masks."""
    n_gt = np.count_nonzero(gt)
    n_pred = np.count_nonzero(pred)
    if n_pred + n_gt == 0:
        # Both masks are empty: count this as perfect agreement.
        return 1, 1
    n_inter = np.count_nonzero(gt * pred)
    n_union = np.count_nonzero(pred + gt)
    # The small epsilon keeps the original numeric behaviour.
    dice = (2 * n_inter) / (n_pred + n_gt + 0.000001)
    iou = n_inter / (n_union + 0.000001)
    return dice, iou


def evaluate_seg_result(result_path, label_path, save_name='test_rst.txt', cum_time=None):
    """Score predicted segmentation masks against ground-truth labels.

    Every ``.png`` in ``result_path`` is compared with the file of the same
    name in ``label_path``. Dice and IoU are computed on the full mask and on
    a band around the ground-truth contour (``ct_*`` values). The means are
    printed and written to ``save_name``.

    Args:
        result_path: directory with predicted masks; channel 1 divided by 255
            is used as the prediction.
        label_path: directory with ground-truth masks; channel 0 is used.
        save_name: text file that receives the one-line summary.
        cum_time: optional sequence of timings. When ``None`` or empty, the
            mean time is reported as ``nan`` and the count as ``0``.

    Returns:
        None. Results are printed and written to ``save_name``.

    Raises:
        FileNotFoundError: if a label or result image cannot be read.
    """
    dices, ious, ct_dices, ct_ious = [], [], [], []

    for file in os.listdir(result_path):
        if not file.endswith(".png"):
            continue

        gt_file = os.path.join(label_path, file)
        gt = cv2.imread(gt_file)
        if gt is None:
            raise FileNotFoundError("cannot read label image: {}".format(gt_file))
        gt = gt[:, :, 0]
        if 'EP' in file:
            # Labels of "EP" files are binarised to {0, 1}.
            gt[gt > 0] = 1

        # Band around the ground-truth contour (line thickness 20).
        # NOTE(review): gt * 255 presumes gt holds only 0/1 values; larger
        # uint8 labels would wrap around -- confirm for non-"EP" data.
        contour_mask = np.zeros_like(gt)
        # findContours returns 2 or 3 values depending on the OpenCV version;
        # the contour list is always the second-to-last element.
        contours = cv2.findContours(gt * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
        cv2.drawContours(contour_mask, contours, -1, (1, 1, 1), 20)
        ct_gt = gt * contour_mask

        out_file = os.path.join(result_path, file)
        output = cv2.imread(out_file)
        if output is None:
            raise FileNotFoundError("cannot read result image: {}".format(out_file))
        # Bring the prediction to the label resolution without blending values.
        output = cv2.resize(output, (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_NEAREST)
        output = output[:, :, 1] / 255
        ct_output = output * contour_mask

        dice, iou = _overlap_scores(gt, output)
        ct_dice, ct_iou = _overlap_scores(ct_gt, ct_output)

        dices.append(dice)
        ious.append(iou)
        ct_dices.append(ct_dice)
        ct_ious.append(ct_iou)

    mean_dice = np.mean(dices)
    mean_iou = np.mean(ious)
    ct_mean_dice = np.mean(ct_dices)
    ct_mean_iou = np.mean(ct_ious)

    if cum_time is None or len(cum_time) == 0:
        # No timings supplied (this is the default): report placeholders
        # instead of failing inside np.mean / len.
        mean_time = float('nan')
        num_time = 0
    else:
        mean_time = np.mean(cum_time)
        num_time = len(cum_time)

    print("mean_dice={},mean_iou={},ct_mean_dice={},ct_mean_iou={}".format(mean_dice,mean_iou,ct_mean_dice,ct_mean_iou))
    print("mean time: {}ms".format(mean_time))
    with open(save_name, 'w') as report:
        report.write("mean_dice={},mean_iou={},ct_mean_dice={},ct_mean_iou={},mean_time={},num_time={}\n".format(mean_dice,mean_iou,ct_mean_dice,ct_mean_iou,mean_time,num_time))
218+
219+
220+
def bilinear_sampler(imgs, coords):
    """Sample ``imgs`` at fractional pixel positions with bilinear weights.

    Neighbours that fall outside the source image get weight 0, so points
    outside the source boundary contribute nothing to the output.

    Args:
        imgs: source images, [batch, height_s, width_s, channels].
        coords: sampling positions, [batch, height_t, width_t, 2]. The last
            axis holds x then y. height_t/width_t define the output size and
            need not equal height_s/width_s.

    Returns:
        A pair ``(output, wmask)``: ``output`` is the sampled image
        [batch, height_t, width_t, channels]; ``wmask`` is the sum of the
        four interpolation weights at every output position.
    """
    def _repeat(x, n_repeats):
        # Outer product with a row of ones, flattened: every entry of x is
        # repeated n_repeats times in a row.
        ones_col = tf.expand_dims(tf.ones(shape=tf.stack([
            n_repeats,
        ])), 1)
        ones_row = tf.cast(tf.transpose(ones_col, [1, 0]), 'float32')
        tiled = tf.matmul(tf.reshape(x, (-1, 1)), ones_row)
        return tf.reshape(tiled, [-1])

    with tf.name_scope('image_sampling'):
        coords_x, coords_y = tf.split(coords, [1, 1], axis=3)
        src_shape = imgs.get_shape()
        tgt_shape = coords.get_shape()
        out_size = coords.get_shape().as_list()
        out_size[3] = imgs.get_shape().as_list()[3]

        coords_x = tf.cast(coords_x, 'float32')
        coords_y = tf.cast(coords_y, 'float32')

        # Integer corners surrounding every sampling position.
        x0 = tf.floor(coords_x)
        x1 = x0 + 1
        y0 = tf.floor(coords_y)
        y1 = y0 + 1

        y_max = tf.cast(tf.shape(imgs)[1] - 1, 'float32')
        x_max = tf.cast(tf.shape(imgs)[2] - 1, 'float32')
        zero = tf.zeros([1], dtype='float32')

        # Clamped corners are always valid gather indices.
        x0_safe = tf.clip_by_value(x0, zero, x_max)
        y0_safe = tf.clip_by_value(y0, zero, y_max)
        x1_safe = tf.clip_by_value(x1, zero, x_max)
        y1_safe = tf.clip_by_value(y1, zero, y_max)

        # A corner that had to be clamped lies outside the grid; the
        # equality test zeroes its weight.
        wt_x0 = (x1 - coords_x) * tf.cast(tf.equal(x0, x0_safe), 'float32')
        wt_x1 = (coords_x - x0) * tf.cast(tf.equal(x1, x1_safe), 'float32')
        wt_y0 = (y1 - coords_y) * tf.cast(tf.equal(y0, y0_safe), 'float32')
        wt_y1 = (coords_y - y0) * tf.cast(tf.equal(y1, y1_safe), 'float32')

        # Offsets into the flattened [batch*height_s*width_s, channels] image.
        row_stride = tf.cast(src_shape[2], 'float32')
        img_stride = tf.cast(src_shape[2] * src_shape[1], 'float32')
        batch_offsets = _repeat(
            tf.cast(tf.range(tgt_shape[0]), 'float32') * img_stride,
            tgt_shape[1] * tgt_shape[2])
        base = tf.reshape(
            batch_offsets,
            [out_size[0], out_size[1], out_size[2], 1])

        base_y0 = base + y0_safe * row_stride
        base_y1 = base + y1_safe * row_stride
        idx00 = tf.reshape(x0_safe + base_y0, [-1])
        idx01 = x0_safe + base_y1
        idx10 = x1_safe + base_y0
        idx11 = x1_safe + base_y1

        imgs_flat = tf.reshape(imgs, tf.stack([-1, src_shape[3]]))
        imgs_flat = tf.cast(imgs_flat, 'float32')

        def _corner_pixels(flat_idx):
            # Look up the pixels at the given flat indices and restore the
            # output layout.
            picked = tf.gather(imgs_flat, tf.cast(flat_idx, 'int32'))
            return tf.reshape(picked, out_size)

        im00 = _corner_pixels(idx00)
        im01 = _corner_pixels(idx01)
        im10 = _corner_pixels(idx10)
        im11 = _corner_pixels(idx11)

        w00 = wt_x0 * wt_y0
        w01 = wt_x0 * wt_y1
        w10 = wt_x1 * wt_y0
        w11 = wt_x1 * wt_y1

        output = tf.add_n([
            w00 * im00, w01 * im01,
            w10 * im10, w11 * im11
        ])

        wmask = w00 + w01 + w10 + w11

        return output, wmask

0 commit comments

Comments
 (0)
Please sign in to comment.