Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
dd61776
test
Apr 13, 2020
6c17fa1
test
Apr 13, 2020
2b2a8fb
test
YudongYao Apr 13, 2020
a42b5d2
Merge pull request #1 from tomography/master
YudongYao Apr 21, 2020
1adce5d
Merge pull request #2 from tomography/master
YudongYao Apr 29, 2020
b105101
test
YudongYao Apr 30, 2020
f61c2e1
Merge branch 'master' of https://github.com/YudongYao/tike
YudongYao Apr 30, 2020
5115619
Merge pull request #3 from tomography/master
YudongYao May 6, 2020
90f9761
Merge pull request #4 from tomography/master
YudongYao May 13, 2020
207cbae
Merge pull request #5 from tomography/master
YudongYao May 21, 2020
48ff26e
Merge pull request #6 from tomography/master
YudongYao May 26, 2020
70e27f1
Merge pull request #7 from tomography/master
YudongYao May 31, 2020
5e27ab9
Merge pull request #8 from tomography/master
YudongYao Jun 1, 2020
f0e9011
Merge pull request #9 from tomography/master
YudongYao Jun 11, 2020
7acb83b
multi-wavelength probe generation
YudongYao Jun 12, 2020
b666a0d
add another diemension for energy
YudongYao Jun 12, 2020
6ccc5d0
add another loop for energy in position correction
YudongYao Jun 12, 2020
f9d0714
add another loop for energy in probe update
YudongYao Jun 12, 2020
ee0eab3
add another dimension for energy
YudongYao Jun 12, 2020
d4c2385
add another dimension for energy
YudongYao Jun 12, 2020
2b95dc8
add another dimension for energy
YudongYao Jun 12, 2020
7e0f898
add another dimension for energy
YudongYao Jun 12, 2020
3d476ab
test on simulated data
YudongYao Jun 12, 2020
f2b7772
test multi-energy with simulation
YudongYao Jun 12, 2020
1744b9c
add another dimension for energy
YudongYao Jun 12, 2020
e8a4c54
fix some format issue
YudongYao Jun 15, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/tike/operators/numpy/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ class Convolution(Operator):
psi : (ntheta, nz, n) complex64
The complex wavefront modulation of the object.
probe : complex64
The (ntheta, nscan // fly, fly, 1, probe_shape, probe_shape)
The (ntheta, nscan // fly, fly, energy, 1, probe_shape, probe_shape)
complex illumination function.
nearplane: complex64
The (ntheta, nscan // fly, fly, 1, probe_shape, probe_shape)
The (ntheta, nscan // fly, fly, energy, 1, probe_shape, probe_shape)
wavefronts after exiting the object.
scan : (ntheta, nscan, 2) float32
Coordinates of the minimum corner of the probe grid for each
Expand All @@ -40,12 +40,13 @@ class Convolution(Operator):

"""

def __init__(self, probe_shape, nz, n, ntheta, fly=1,
def __init__(self, probe_shape, nz, n, ntheta, energy=1, fly=1,
detector_shape=None, **kwargs): # yapf: disable
self.probe_shape = probe_shape
self.nz = nz
self.n = n
self.ntheta = ntheta
self.energy = energy
self.fly = fly
if detector_shape is None:
self.detector_shape = probe_shape
Expand All @@ -69,7 +70,7 @@ def fwd(self, psi, scan, probe):
)
patches = self._patch(patches, psi, scan, fwd=True)
patches = patches.reshape(self.ntheta, scan.shape[-2] // self.fly,
self.fly, 1, self.detector_shape,
self.fly, 1, 1, self.detector_shape,
self.detector_shape)
patches[..., self.pad:self.end, self.pad:self.end] *= probe
return patches
Expand Down Expand Up @@ -97,7 +98,7 @@ def adj_probe(self, nearplane, scan, psi, overwrite=False):
)
patches = self._patch(patches, psi, scan, fwd=True)
patches = patches.reshape(self.ntheta, scan.shape[-2] // self.fly,
self.fly, 1, self.probe_shape,
self.fly, 1, 1, self.probe_shape,
self.probe_shape)
patches = patches.conj()
patches *= nearplane[..., self.pad:self.end, self.pad:self.end]
Expand All @@ -107,18 +108,18 @@ def _check_shape_probe(self, x, nscan):
"""Check that the probe is correctly shaped."""
assert type(x) is self.xp.ndarray, type(x)
# unique probe for each position
shape1 = (self.ntheta, nscan // self.fly, self.fly, 1, self.probe_shape,
self.probe_shape)
shape1 = (self.ntheta, nscan // self.fly, self.fly, 1, 1,
self.probe_shape, self.probe_shape)
# one probe for all positions
shape2 = (self.ntheta, 1, 1, 1, self.probe_shape, self.probe_shape)
shape2 = (self.ntheta, 1, 1, 1, 1, self.probe_shape, self.probe_shape)
if __debug__ and x.shape != shape2 and x.shape != shape1:
raise ValueError(
f"probe must have shape {shape1} or {shape2} not {x.shape}")

def _check_shape_nearplane(self, x, nscan):
"""Check that nearplane is correctly shaped."""
assert type(x) is self.xp.ndarray, type(x)
shape1 = (self.ntheta, nscan // self.fly, self.fly, 1,
shape1 = (self.ntheta, nscan // self.fly, self.fly, 1, 1,
self.detector_shape, self.detector_shape)
if __debug__ and x.shape != shape1:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions src/tike/operators/numpy/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy


class Operator(ABC):
"""A base class for Operators.

Expand Down
6 changes: 3 additions & 3 deletions src/tike/operators/numpy/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Propagation(Operator):
farplane: (..., detector_shape, detector_shape) complex64
The wavefronts hitting the detector respectively.
Shape for cost functions and gradients is
(ntheta, nscan // fly, fly, 1, detector_shape, detector_shape).
(ntheta, nscan // fly, fly, 1, 1, detector_shape, detector_shape).
data, intensity : (ntheta, nscan, detector_shape, detector_shape) complex64
data is the square of the absolute value of `farplane`. `data` is the
intensity of the `farplane`.
Expand Down Expand Up @@ -86,12 +86,12 @@ def _gaussian_cost(self, data, intensity):
def _gaussian_grad(self, data, farplane, intensity, overwrite=False):
return farplane * (
1 - np.sqrt(data) / (np.sqrt(intensity) + 1e-32)
)[:, :, np.newaxis, np.newaxis] # yapf:disable
)[:, :, np.newaxis, np.newaxis, np.newaxis] # yapf:disable

def _poisson_cost(self, data, intensity):
return np.sum(intensity - data * np.log(intensity + 1e-32))

def _poisson_grad(self, data, farplane, intensity, overwrite=False):
return farplane * (
1 - data / (intensity + 1e-32)
)[:, :, np.newaxis, np.newaxis] # yapf: disable
)[:, :, np.newaxis, np.newaxis, np.newaxis] # yapf: disable
60 changes: 32 additions & 28 deletions src/tike/operators/numpy/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ class Ptycho(Operator):
psi : (ntheta, nz, n) complex64
The complex wavefront modulation of the object.
probe : complex64
The complex (ntheta, nscan // fly, fly, 1, probe_shape,
The complex (ntheta, nscan // fly, fly, 1, 1, probe_shape,
probe_shape) illumination function.
mode : complex64
A single (ntheta, nscan // fly, fly, 1, probe_shape, probe_shape)
A single (ntheta, nscan // fly, fly,1, 1, probe_shape, probe_shape)
probe mode.
nearplane, farplane: complex64
The (ntheta, nscan // fly, fly, 1, detector_shape, detector_shape)
Expand Down Expand Up @@ -128,43 +128,47 @@ def adj_probe(self, farplane, scan, psi, overwrite=False, **kwargs):
overwrite=True,
)

def _compute_intensity(self, data, psi, scan, probe, n=-1, mode=None):
def _compute_intensity(self, data, psi, scan, probe, t=-1, n=-1, mode=None):
"""Compute detector intensities replacing the nth probe mode"""
intensity = 0
for m in range(probe.shape[-3]):
intensity += np.sum(
np.square(np.abs(self.fwd(
psi=psi,
scan=scan,
probe=mode if m == n else probe[..., m:m + 1, :, :],
).reshape(*data.shape[:2], -1, *data.shape[2:]))),
axis=2,
) # yapf: disable
for w in range(probe.shape[-4]):
for m in range(probe.shape[-3]):
intensity += np.sum(
np.square(np.abs(self.fwd(
psi=psi,
scan=scan,
probe = mode if (w == t and m == n) else probe[..., w:w+1, m:m + 1, :, :],
).reshape(*data.shape[:2], -1, *data.shape[2:]))),
axis=2,
) # yapf: disable
return intensity

def cost(self, data, psi, scan, probe, n=-1, mode=None):
intensity = self._compute_intensity(data, psi, scan, probe, n, mode)
def cost(self, data, psi, scan, probe, t=-1, n=-1, mode=None):
intensity = self._compute_intensity(data, psi, scan, probe, t, n, mode)
return self.propagation.cost(data, intensity)

def grad(self, data, psi, scan, probe):
intensity = self._compute_intensity(data, psi, scan, probe)
grad_obj = self.xp.zeros_like(psi)
for mode in np.split(probe, probe.shape[-3], axis=-3):
# TODO: Pass obj through adj() instead of making new obj inside
grad_obj += self.adj(
farplane=self.propagation.grad(
data,
self.fwd(psi=psi, scan=scan, probe=mode),
intensity,
),
probe=mode,
scan=scan,
overwrite=True,
)
for i in range(probe.shape[-4]):
for mode in np.split(probe[..., i:i + 1, :, :, :],
probe[..., i:i + 1, :, :, :].shape[-3],
axis=-3):
# TODO: Pass obj through adj() instead of making new obj inside
grad_obj += self.adj(
farplane=self.propagation.grad(
data,
self.fwd(psi=psi, scan=scan, probe=mode),
intensity,
),
probe=mode,
scan=scan,
overwrite=True,
)
return grad_obj

def grad_probe(self, data, psi, scan, probe, n=-1, mode=None):
intensity = self._compute_intensity(data, psi, scan, probe, n, mode)
def grad_probe(self, data, psi, scan, probe, t=-1, n=-1, mode=None):
intensity = self._compute_intensity(data, psi, scan, probe, t, n, mode)
return self.adj_probe(
farplane=self.propagation.grad(
data,
Expand Down
9 changes: 6 additions & 3 deletions src/tike/ptycho/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,23 @@ def update_positions_pd(operator, data, psi, probe, scan,
dI = (data - intensity).reshape(*data.shape[:-2], np.prod(data.shape[-2:]))

dI_dx, dI_dy = 0, 0

#update position use the central wavelength only
i = 1
for m in range(probe.shape[-3]):

# step 2: the partial derivatives of wavefront respect to position
farplane = operator.fwd(psi=psi,
scan=scan,
probe=probe[..., m:m + 1, :, :])
probe=probe[..., i:i + 1, m:m + 1, :, :])
dfarplane_dx = (farplane - operator.fwd(
psi=psi,
probe=probe[..., m:m + 1, :, :],
probe=probe[..., i:i + 1, m:m + 1, :, :],
scan=scan + operator.xp.array((0, dx), dtype='float32'),
)) / dx
dfarplane_dy = (farplane - operator.fwd(
psi=psi,
probe=probe[..., m:m + 1, :, :],
probe=probe[..., i:i + 1, m:m + 1, :, :],
scan=scan + operator.xp.array((dx, 0), dtype='float32'),
)) / dx

Expand Down
159 changes: 159 additions & 0 deletions src/tike/ptycho/probe_MW.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
##This is the python version code for initialize probes for multi-wavelength method

import numpy as np
from scipy.interpolate import RectBivariateSpline


def MW_probe(probe_shape, energy, dx_dec, dis_defocus, dis_StoD, **kwargs):
# return probe sorted by the spectrum
# return scale is the wavelength dependent pixel scaling factor
"""
Summary of this function goes here
Parameters: probe_shape -> the matrix size for probe
energy -> number of wavelength for multi-wavelength recontruction
dx_dec -> pixel size on detector
dis_defocus -> defocus distance of sample plane to the focus of the FZP
dis_StoD -> sample to detector distance
kwargs -> setup: 'velo','2idd'
-> spectrum: measured spectrum (if available) [wavelength,intensity]
-> bandwidth
-> lambda0: central wavelength used to generate spectrum if no measured one was given
"""

if 'spectrum' in kwargs:
spectrum = kwargs.get('spectrum')
spectrum = spectrum[:, ::spectrum.shape[1] // energy][:, :energy]
lambda0 = spectrum[np.argmax(spectrum[:, 1]), 0]
else:
if 'bandwidth' in kwargs:
bandwidth = kwargs.get('bandwidth')
else:
bandwidth = 0.01

if 'lambda0' in kwargs:
lambda0 = kwargs.get('lambda0')
else:
lambda0 = 1.24e-9 / 8.8
spectrum = gaussian_spectrum(lambda0, bandwidth, energy)

spectrum = spectrum[np.argsort(-spectrum[:, 1])]

if 'setup' in kwargs:
setup = kwargs.get('setup')
else:
setup = 'default'
probe = np.zeros((energy, 1, probe_shape, probe_shape), dtype=np.complex)

# pixel size on sample plane (central wavelength)
dx = spectrum[0, 0] * dis_StoD / probe_shape / dx_dec
# focal length for central wavelength
_, _, FL0 = fzp_calculate(spectrum[0, 0], dis_defocus, probe_shape, dx,
setup)

for i in range(energy):
# get zone plate parameter
T, dx_fzp, _ = fzp_calculate(spectrum[i, 0], dis_defocus, probe_shape,
dx, setup)
nprobe = fresnel_propagation(T, dx_fzp, (FL0 + dis_defocus),
spectrum[i, 0])
nprobe = nprobe / (np.sqrt(np.sum(np.abs(nprobe)**2)))
probe[i, 0, :, :] = nprobe * (np.sqrt(spectrum[i, 1]))

return probe[np.newaxis, np.newaxis, np.newaxis]


def gaussian_spectrum(lambda0, bandwidth, energy):
spectrum = np.zeros((energy, 2))
sigma = lambda0 * bandwidth / 2.355
d_lam = sigma * 4 / (energy - 1)
spectrum[:, 0] = np.arange(-1 * np.floor(energy / 2), np.ceil(
energy / 2)) * d_lam + lambda0
spectrum[:, 1] = np.exp(-(spectrum[:, 0] - lambda0)**2 / sigma**2)
return spectrum


def fzp_calculate(wavelength, dis_defocus, M, dx, setup):
"""
this function can calculate the transfer function of zone plate
return the transfer function, and the pixel sizes
"""

FZP_para = get_setup(setup)

FL = 2 * FZP_para['radius'] * FZP_para['outmost'] / wavelength

# pixel size on FZP plane
dx_fzp = wavelength * (FL + dis_defocus) / M / dx
# coordinate on FZP plane
lx_fzp = -dx_fzp * np.arange(-1 * np.floor(M / 2), np.ceil(M / 2))

XX_FZP, YY_FZP = np.meshgrid(lx_fzp, lx_fzp)
# transmission function of FZP
T = np.exp(-1j * 2 * np.pi / wavelength * (XX_FZP**2 + YY_FZP**2) / 2 / FL)
C = np.sqrt(XX_FZP**2 + YY_FZP**2) <= FZP_para['radius']
H = np.sqrt(XX_FZP**2 + YY_FZP**2) >= FZP_para['CS'] / 2

return T * C * H, dx_fzp, FL


def get_setup(setup):
switcher = {
'velo': {
'radius': 90e-6,
'outmost': 50e-9,
'CS': 60e-6
},
'2idd': {
'radius': 80e-6,
'outmost': 70e-9,
'CS': 60e-6
},
'default': {
'radius': 90e-6,
'outmost': 50e-9,
'CS': 60e-6
},
}
FZP_para = switcher.get(setup)
return FZP_para


def fresnel_propagation(input, dxy, z, wavelength):
"""
This is the python version code for fresnel propagation
Summary of this function goes here
Parameters: dx,dy -> the pixel pitch of the object
z -> the distance of the propagation
lambda -> the wave length
X,Y -> meshgrid of coordinate
input -> input object
"""

(M, N) = input.shape
k = 2 * np.pi / wavelength
# the coordinate grid
M_grid = np.arange(-1 * np.floor(M / 2), np.ceil(M / 2))
N_grid = np.arange(-1 * np.floor(N / 2), np.ceil(N / 2))
lx = M_grid * dxy
ly = N_grid * dxy

XX, YY = np.meshgrid(lx, ly)

# the coordinate grid on the output plane
fc = 1 / dxy
fu = wavelength * z * fc
lu = M_grid * fu / M
lv = N_grid * fu / N
Fx, Fy = np.meshgrid(lu, lv)

if z > 0:
pf = np.exp(1j * k * z) * np.exp(1j * k * (Fx**2 + Fy**2) / 2 / z)
kern = input * np.exp(1j * k * (XX**2 + YY**2) / 2 / z)
cgh = np.fft.fft2(np.fft.fftshift(kern))
OUT = np.fft.fftshift(cgh * np.fft.fftshift(pf))
else:
pf = np.exp(1j * k * z) * np.exp(1j * k * (XX**2 + YY**2) / 2 / z)
cgh = np.fft.ifft2(
np.fft.fftshift(input * np.exp(1j * k * (Fx**2 + Fy**2) / 2 / z)))
OUT = np.fft.fftshift(cgh) * pf
return OUT
Loading