-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAssistVisualization.py
338 lines (298 loc) · 14.9 KB
/
AssistVisualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import pydicom
import os,re
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
from IPython.display import clear_output
import math
import cv2
# Upper bound of the intensity range used by the DICOM export code in this
# module: registered voxel values are rescaled towards [0, p] and clipped to it.
p=256 #for dicom scaling
# Small constant kept for DICOM intensity scaling; it is not referenced by any
# live code in the visible part of this module.
ESP=0.1 #for dicom scaling
def display_images(fixed_image_z, moving_image_z, fixed_npa, moving_npa):
    """Show one slice of the fixed volume beside one slice of the moving volume.

    Args:
        fixed_image_z: Index along the first axis of ``fixed_npa`` to display.
        moving_image_z: Index along the first axis of ``moving_npa`` to display.
        fixed_npa: Array indexed as ``[z, :, :]`` holding the fixed image.
        moving_npa: Array indexed as ``[z, :, :]`` holding the moving image.
    """
    # One figure, two side-by-side panels.
    plt.subplots(1, 2, figsize=(10, 8))
    panels = (
        (1, fixed_npa, fixed_image_z, 'fixed image'),
        (2, moving_npa, moving_image_z, 'moving image'),
    )
    for position, volume, z_index, caption in panels:
        plt.subplot(1, 2, position)
        plt.imshow(volume[z_index, :, :], cmap=plt.cm.Greys_r)
        plt.title(caption)
        plt.axis('off')
    plt.show()
# Callback for the IPython interact method: scrolls through an image stack and
# alpha-blends two images that occupy the same physical space.
def display_images_with_alpha(image_z, alpha, fixed, moving):
    """Display slice ``image_z`` of ``fixed`` and ``moving`` blended together.

    Args:
        image_z: Slice index along the first axis of the converted arrays.
        alpha: Weight given to ``moving``; ``fixed`` receives ``1.0 - alpha``.
        fixed: SimpleITK image (converted with ``sitk.GetArrayFromImage``).
        moving: SimpleITK image (converted with ``sitk.GetArrayFromImage``).
    """
    fixed_volume = sitk.GetArrayFromImage(fixed)
    moving_volume = sitk.GetArrayFromImage(moving)
    fixed_slice = fixed_volume[image_z, :, :]
    moving_slice = moving_volume[image_z, :, :]
    blended = cv2.addWeighted(fixed_slice, 1.0 - alpha, moving_slice, alpha, 0)
    plt.figure(figsize=(5, 5))
    plt.imshow(blended, cmap=plt.cm.Greys_r)
    plt.axis('off')
    plt.show()
def display_images_with_mask(image_z, fixed, moving):
    """Display slice ``image_z`` of ``fixed`` overlaid with ``moving`` as a mask.

    Args:
        image_z: Slice index along the first axis of the converted arrays.
        fixed: SimpleITK image providing the background slice.
        moving: SimpleITK image used as the mask; its values are weighted by
            ``0.5 * 255`` (presumably a 0/1 mask -- confirm with callers).
    """
    image_volume = sitk.GetArrayFromImage(fixed)
    mask_volume = sitk.GetArrayFromImage(moving)  # mask
    half_image = image_volume[image_z, :, :] * 0.5
    half_mask = mask_volume[image_z, :, :] * 0.5 * 255
    plt.imshow(half_image + half_mask, cmap=plt.cm.Greys_r)
    plt.axis('off')
    plt.show()
# Callback for the StartEvent: opens a figure and resets the recorded data.
def start_plot():
    """Create a new figure and start with empty metric / resolution-level lists."""
    global metric_values, multires_iterations
    plt.figure()
    metric_values, multires_iterations = [], []
# Callback for the EndEvent: discards the recorded data and the figure.
def end_plot():
    """Delete the module-level plotting state and close the current figure."""
    global metric_values, multires_iterations
    del metric_values, multires_iterations
    # Close the figure so the plot is not displayed a second time later on.
    plt.close()
# Callback for the IterationEvent: records the new metric value and redraws.
def plot_values(registration_method):
    """Append the current metric value and replot the whole metric history.

    Args:
        registration_method: Object exposing ``GetMetricValue()``.
    """
    metric_values.append(registration_method.GetMetricValue())
    # wait=True reduces flickering while the output area is refreshed.
    clear_output(wait=True)
    # Metric curve in red; blue stars mark the entries of multires_iterations.
    plt.plot(metric_values, 'r')
    level_start_values = list(map(metric_values.__getitem__, multires_iterations))
    plt.plot(multires_iterations, level_start_values, 'b*')
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('Metric Value', fontsize=12)
    plt.show()
# Callback for the sitkMultiResolutionIterationEvent: remembers the index into
# metric_values at which the event fired.
def update_multires_iterations():
    """Record the current length of ``metric_values`` in ``multires_iterations``."""
    recorded_so_far = len(metric_values)
    multires_iterations.append(recorded_so_far)
def _clipped_int16_volume(image, min_moving, max_moving):
    """Rescale ``image`` towards [0, p] and return it as an int16 NumPy volume.

    The expression ``(image - min) / (max - min) * p`` is evaluated on the
    SimpleITK image, the result is converted to an array, values are clipped to
    ``[0, p]`` and cast to ``int16``. Slice ``i`` of the image is ``volume[i]``.
    """
    # NOTE(review): nothing guards against max_moving == min_moving (a zero
    # divisor); the original code did not either -- confirm the desired result.
    scaled = (image - min_moving) / (max_moving - min_moving) * p
    volume = sitk.GetArrayFromImage(scaled)
    volume[volume < 0] = 0
    volume[volume > p] = p
    return volume.astype('int16')


def _copy_slice_geometry(ds, ds_ori):
    """Copy position, orientation, slice location and pixel grid tags to ``ds``."""
    ds[0x20, 0x32].value = ds_ori[0x20, 0x32].value  # ImagePositionPatient
    ds[0x20, 0x37].value = ds_ori[0x20, 0x37].value  # ImageOrientationPatient
    ds.SliceLocation = ds_ori.SliceLocation
    ds[0x28, 0x10].value = ds_ori[0x28, 0x10].value  # Rows
    ds[0x28, 0x11].value = ds_ori[0x28, 0x11].value  # Columns
    ds[0x28, 0x30].value = ds_ori[0x28, 0x30].value  # PixelSpacing


def _write_slice(ds, pixels, newpath):
    """Store ``pixels`` as the dataset's pixel data, save it and print the path."""
    # tobytes() replaces ndarray.tostring(), which was removed in NumPy 2.0.
    ds.PixelData = pixels.tobytes()
    ds.save_as(newpath)
    print(newpath)


def save_transform_and_image(transform, fixed_image, moving_image, fixed_ori, moving_ori, dicompath, outputfile_prefix, multi_tp):
    """Resample the moving image onto the fixed grid and write it out as DICOM.

    ``moving_image`` is resampled with ``transform`` (B-spline interpolation,
    ``fixed_image`` as reference), its intensities are rescaled towards
    ``[0, p]``, and one DICOM file is written per entry of ``moving_ori``: each
    moving file is read, its geometry tags and pixel data are replaced, and it
    is saved to ``dicompath + <basename of the moving file>``.

    Args:
        transform: SimpleITK transform handed to the resample filter.
        fixed_image: SimpleITK image used as the resampling reference.
        moving_image: SimpleITK image to resample.
        fixed_ori: Sequence of DICOM file paths of the fixed series
            (assumed to be in slice order -- TODO confirm with callers).
        moving_ori: Sequence of DICOM file paths of the moving series; these
            files provide the headers of the written slices.
        dicompath: String prefix for output paths; it is concatenated with the
            file name as-is, no path separator is inserted.
        outputfile_prefix: Unused by this function.
        multi_tp: When equal to ``False`` the slice geometry is derived from
            the first/last fixed slice and the two series may differ in
            length; otherwise slice ``i`` takes its geometry from
            ``fixed_ori[i]``.
    """
    resample = sitk.ResampleImageFilter()
    resample.SetReferenceImage(fixed_image)
    resample.SetInterpolator(sitk.sitkBSpline)
    resample.SetTransform(transform)

    # Resample once and reuse the result; the filter inputs never change, so
    # re-executing it for every slice only repeated the same work.
    reg_img = resample.Execute(moving_image)
    exqimg = sitk.GetArrayFromImage(reg_img)
    print(exqimg.shape)

    # Intensity window of the registered volume; both bounds start at 0, so
    # the window always contains 0.
    max_moving = 0
    min_moving = 0
    for plane in exqimg:
        max_moving = max(max_moving, np.amax(plane))
        min_moving = min(min_moving, np.amin(plane))

    n_fixed = len(fixed_ori)
    n_moving = len(moving_ori)

    # Kept as an equality test (not ``not multi_tp``) so that values such as
    # None keep selecting the same branch as before.
    if multi_tp == False:  # noqa: E712
        ds_first = pydicom.dcmread(fixed_ori[0])
        ds_last = pydicom.dcmread(fixed_ori[-1])
        first_pos = ds_first[0x20, 0x32].value
        last_pos = ds_last[0x20, 0x32].value

        # Per-slice displacement of the fixed series along x, y and z.
        # NOTE(review): divides by the slice count, not count - 1 -- confirm.
        step = [(last_pos[k] - first_pos[k]) / n_fixed for k in range(3)]

        # Choose the axis that SliceLocation follows.
        # NOTE(review): this compares signed differences, not absolute ones.
        offsets = [ds_last.SliceLocation - float(last_pos[k]) for k in range(3)]
        smallest = min(offsets)
        if smallest == offsets[1]:
            axis = 1
        elif smallest == offsets[2]:
            axis = 2
        else:
            axis = 0

        ratio = n_fixed / n_moving if n_moving else 0.0
        scaled_step = [component * ratio for component in step]

        if n_moving < n_fixed:
            # Fewer moving than fixed slices: downsample along z so that the
            # volume has exactly one slice per moving file.
            old_size = reg_img.GetSize()
            new_size = [old_size[0], old_size[1], n_moving]
            new_spacing = [
                old_sz * old_spc / new_sz
                for old_sz, old_spc, new_sz in zip(old_size, reg_img.GetSpacing(), new_size)
            ]
            reg_img_resample = sitk.Resample(
                reg_img, new_size, sitk.Transform(), sitk.sitkLinear,
                reg_img.GetOrigin(), new_spacing, reg_img.GetDirection(),
                0.0, reg_img.GetPixelIDValue())
            volume = _clipped_int16_volume(reg_img_resample, min_moving, max_moving)

            slice_gap = math.sqrt(
                step[0] * step[0] + step[1] * step[1] + step[2] * step[2]) * ratio
            for i, path in enumerate(moving_ori):
                ds = pydicom.dcmread(path)
                tail = os.path.split(path)[1]
                ds[0x18, 0x88].value = slice_gap  # SpacingBetweenSlices
                ds[0x20, 0x37].value = ds_first[0x20, 0x37].value
                ds.SliceLocation = ds_first.SliceLocation + scaled_step[axis] * i
                ds[0x20, 0x32].value = [
                    first_pos[k] + scaled_step[k] * i for k in range(3)]
                ds[0x28, 0x10].value = ds_first[0x28, 0x10].value
                ds[0x28, 0x11].value = ds_first[0x28, 0x11].value
                ds[0x28, 0x30].value = ds_first[0x28, 0x30].value
                _write_slice(ds, volume[i], dicompath + tail)
        else:
            # At least as many moving as fixed slices: the first n_fixed
            # outputs mirror the fixed slices, the remainder are zero-filled.
            volume = _clipped_int16_volume(reg_img, min_moving, max_moving)
            for i, path in enumerate(moving_ori):
                ds = pydicom.dcmread(path)
                tail = os.path.split(path)[1]
                if i < n_fixed:
                    _copy_slice_geometry(ds, pydicom.dcmread(fixed_ori[i]))
                    pixels = volume[i]
                else:
                    # NOTE(review): the first padding slice has offset 0 and
                    # therefore repeats the last fixed slice's position.
                    pad = i - n_fixed
                    ds[0x20, 0x32].value = [
                        last_pos[k] + scaled_step[k] * pad for k in range(3)]
                    ds[0x20, 0x37].value = ds_last[0x20, 0x37].value
                    ds.SliceLocation = ds_last.SliceLocation + scaled_step[axis] * pad
                    ds[0x28, 0x30].value = ds_last[0x28, 0x30].value
                    ds.Rows = ds_last.Rows
                    ds.Columns = ds_last.Columns
                    pixels = np.zeros((ds_last.Rows, ds_last.Columns)).astype('int16')
                _write_slice(ds, pixels, dicompath + tail)
    else:
        # Slice i of the moving series takes the geometry of fixed slice i.
        volume = _clipped_int16_volume(reg_img, min_moving, max_moving)
        for i, path in enumerate(moving_ori):
            ds = pydicom.dcmread(path)
            tail = os.path.split(path)[1]
            _copy_slice_geometry(ds, pydicom.dcmread(fixed_ori[i]))
            _write_slice(ds, volume[i], dicompath + tail)
# Callback for the StartEvent: resets the recorded data.
def metric_start_plot():
    """Reset the metric history, the resolution-level markers and the iteration counter."""
    global metric_values, multires_iterations
    global current_iteration_number
    metric_values, multires_iterations = [], []
    # -1 never matches a real optimizer iteration, so the first event is kept.
    current_iteration_number = -1
# Callback for the EndEvent: discards the recorded data and the figure.
def metric_end_plot():
    """Delete the module-level plotting state and close the current figure."""
    global metric_values, multires_iterations
    global current_iteration_number
    del metric_values, multires_iterations, current_iteration_number
    # Close the figure so the plot is not displayed a second time later on.
    plt.close()
# Callback for the IterationEvent: records the new metric value and redraws
# the figure.
def metric_plot_values(registration_method):
    """Record and plot the metric value once per optimizer iteration.

    Args:
        registration_method: Object exposing ``GetOptimizerIteration()`` and
            ``GetMetricValue()``.
    """
    global current_iteration_number
    # Some optimizers fire the event for function evaluations as well as for
    # complete iterations; only the first event of each iteration is used.
    iteration = registration_method.GetOptimizerIteration()
    if iteration == current_iteration_number:
        return
    current_iteration_number = iteration
    metric_values.append(registration_method.GetMetricValue())
    # wait=True reduces flickering while the output area is refreshed.
    clear_output(wait=True)
    # Metric curve in red; blue stars mark the entries of multires_iterations.
    plt.plot(metric_values, 'r')
    level_start_values = list(map(metric_values.__getitem__, multires_iterations))
    plt.plot(multires_iterations, level_start_values, 'b*')
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('Metric Value', fontsize=12)
    plt.show()
# Callback for the MultiResolutionIterationEvent: remembers the index into
# metric_values at which the event fired.
def metric_update_multires_iterations():
    """Record the current length of ``metric_values`` in ``multires_iterations``."""
    recorded_so_far = len(metric_values)
    multires_iterations.append(recorded_so_far)
def sortdir(flist):
    """Keep only the ``.dcm`` names in ``flist`` and sort them numerically, in place.

    Names that do not contain ``'.dcm'`` are removed. The remaining names are
    ordered by the integer found in field 3 after splitting the name on
    ``E``, ``S``, ``I`` and ``.dcm`` (e.g. ``E1S2I10.dcm`` -> ``10``).

    Args:
        flist: Mutable list of file names; it is modified in place.

    Returns:
        The same list object, filtered and sorted in ascending order.

    Raises:
        IndexError, ValueError: If two or more names are kept and one of them
            does not contain a parsable number in the expected field.
    """
    def image_number(name):
        # Raw string with an escaped dot so that only a literal ".dcm" splits.
        return int(re.split(r'E|S|I|\.dcm', name)[3])

    # Build the filtered list separately: deleting from a list while iterating
    # over it skips the element that follows each deletion.
    kept = [name for name in flist if '.dcm' in name]
    # A list with fewer than two names was never parsed by the original
    # pairwise comparison, so it is returned without parsing here as well.
    if len(kept) > 1:
        kept.sort(key=image_number)
    flist[:] = kept
    return flist