diff --git a/src/tike/operators/numpy/convolution.py b/src/tike/operators/numpy/convolution.py index cca6a1eb..f89fa266 100644 --- a/src/tike/operators/numpy/convolution.py +++ b/src/tike/operators/numpy/convolution.py @@ -28,10 +28,10 @@ class Convolution(Operator): psi : (ntheta, nz, n) complex64 The complex wavefront modulation of the object. probe : complex64 - The (ntheta, nscan // fly, fly, 1, probe_shape, probe_shape) + The (ntheta, nscan // fly, fly, energy, 1, probe_shape, probe_shape) complex illumination function. nearplane: complex64 - The (ntheta, nscan // fly, fly, 1, probe_shape, probe_shape) + The (ntheta, nscan // fly, fly, energy, 1, probe_shape, probe_shape) wavefronts after exiting the object. scan : (ntheta, nscan, 2) float32 Coordinates of the minimum corner of the probe grid for each @@ -40,12 +40,13 @@ class Convolution(Operator): """ - def __init__(self, probe_shape, nz, n, ntheta, fly=1, + def __init__(self, probe_shape, nz, n, ntheta, energy=1, fly=1, detector_shape=None, **kwargs): # yapf: disable self.probe_shape = probe_shape self.nz = nz self.n = n self.ntheta = ntheta + self.energy = energy self.fly = fly if detector_shape is None: self.detector_shape = probe_shape @@ -69,7 +70,7 @@ def fwd(self, psi, scan, probe): ) patches = self._patch(patches, psi, scan, fwd=True) patches = patches.reshape(self.ntheta, scan.shape[-2] // self.fly, - self.fly, 1, self.detector_shape, + self.fly, 1, 1, self.detector_shape, self.detector_shape) patches[..., self.pad:self.end, self.pad:self.end] *= probe return patches @@ -97,7 +98,7 @@ def adj_probe(self, nearplane, scan, psi, overwrite=False): ) patches = self._patch(patches, psi, scan, fwd=True) patches = patches.reshape(self.ntheta, scan.shape[-2] // self.fly, - self.fly, 1, self.probe_shape, + self.fly, 1, 1, self.probe_shape, self.probe_shape) patches = patches.conj() patches *= nearplane[..., self.pad:self.end, self.pad:self.end] @@ -107,10 +108,10 @@ def _check_shape_probe(self, x, nscan): """Check that the probe is correctly shaped.""" assert type(x) is self.xp.ndarray, type(x) # unique probe for each position - shape1 = (self.ntheta, nscan // self.fly, self.fly, 1, self.probe_shape, - self.probe_shape) + shape1 = (self.ntheta, nscan // self.fly, self.fly, 1, 1, + self.probe_shape, self.probe_shape) # one probe for all positions - shape2 = (self.ntheta, 1, 1, 1, self.probe_shape, self.probe_shape) + shape2 = (self.ntheta, 1, 1, 1, 1, self.probe_shape, self.probe_shape) if __debug__ and x.shape != shape2 and x.shape != shape1: raise ValueError( f"probe must have shape {shape1} or {shape2} not {x.shape}") @@ -118,7 +119,7 @@ def _check_shape_probe(self, x, nscan): def _check_shape_nearplane(self, x, nscan): """Check that nearplane is correctly shaped.""" assert type(x) is self.xp.ndarray, type(x) - shape1 = (self.ntheta, nscan // self.fly, self.fly, 1, + shape1 = (self.ntheta, nscan // self.fly, self.fly, 1, 1, self.detector_shape, self.detector_shape) if __debug__ and x.shape != shape1: raise ValueError( diff --git a/src/tike/operators/numpy/operator.py b/src/tike/operators/numpy/operator.py index ebb23e85..77f30d62 100644 --- a/src/tike/operators/numpy/operator.py +++ b/src/tike/operators/numpy/operator.py @@ -2,6 +2,7 @@ import numpy + class Operator(ABC): """A base class for Operators. diff --git a/src/tike/operators/numpy/propagation.py b/src/tike/operators/numpy/propagation.py index 982925f6..6bb5eb61 100644 --- a/src/tike/operators/numpy/propagation.py +++ b/src/tike/operators/numpy/propagation.py @@ -30,7 +30,7 @@ class Propagation(Operator): farplane: (..., detector_shape, detector_shape) complex64 The wavefronts hitting the detector respectively. Shape for cost functions and gradients is - (ntheta, nscan // fly, fly, 1, detector_shape, detector_shape). + (ntheta, nscan // fly, fly, 1, 1, detector_shape, detector_shape). data, intensity : (ntheta, nscan, detector_shape, detector_shape) complex64 data is the square of the absolute value of `farplane`. `data` is the intensity of the `farplane`. @@ -86,7 +86,7 @@ def _gaussian_cost(self, data, intensity): def _gaussian_grad(self, data, farplane, intensity, overwrite=False): return farplane * ( 1 - np.sqrt(data) / (np.sqrt(intensity) + 1e-32) - )[:, :, np.newaxis, np.newaxis] # yapf:disable + )[:, :, np.newaxis, np.newaxis, np.newaxis] # yapf:disable def _poisson_cost(self, data, intensity): return np.sum(intensity - data * np.log(intensity + 1e-32)) @@ -94,4 +94,4 @@ def _poisson_cost(self, data, intensity): def _poisson_grad(self, data, farplane, intensity, overwrite=False): return farplane * ( 1 - data / (intensity + 1e-32) - )[:, :, np.newaxis, np.newaxis] # yapf: disable + )[:, :, np.newaxis, np.newaxis, np.newaxis] # yapf: disable diff --git a/src/tike/operators/numpy/ptycho.py b/src/tike/operators/numpy/ptycho.py index 893bd734..2c8271c7 100644 --- a/src/tike/operators/numpy/ptycho.py +++ b/src/tike/operators/numpy/ptycho.py @@ -40,10 +40,10 @@ class Ptycho(Operator): psi : (ntheta, nz, n) complex64 The complex wavefront modulation of the object. probe : complex64 - The complex (ntheta, nscan // fly, fly, 1, probe_shape, + The complex (ntheta, nscan // fly, fly, 1, 1, probe_shape, probe_shape) illumination function. mode : complex64 - A single (ntheta, nscan // fly, fly, 1, probe_shape, probe_shape) + A single (ntheta, nscan // fly, fly,1, 1, probe_shape, probe_shape) probe mode. nearplane, farplane: complex64 The (ntheta, nscan // fly, fly, 1, detector_shape, detector_shape) @@ -128,43 +128,47 @@ def adj_probe(self, farplane, scan, psi, overwrite=False, **kwargs): overwrite=True, ) - def _compute_intensity(self, data, psi, scan, probe, n=-1, mode=None): + def _compute_intensity(self, data, psi, scan, probe, t=-1, n=-1, mode=None): """Compute detector intensities replacing the nth probe mode""" intensity = 0 - for m in range(probe.shape[-3]): - intensity += np.sum( - np.square(np.abs(self.fwd( - psi=psi, - scan=scan, - probe=mode if m == n else probe[..., m:m + 1, :, :], - ).reshape(*data.shape[:2], -1, *data.shape[2:]))), - axis=2, - ) # yapf: disable + for w in range(probe.shape[-4]): + for m in range(probe.shape[-3]): + intensity += np.sum( + np.square(np.abs(self.fwd( + psi=psi, + scan=scan, + probe = mode if (w == t and m == n) else probe[..., w:w+1, m:m + 1, :, :], + ).reshape(*data.shape[:2], -1, *data.shape[2:]))), + axis=2, + ) # yapf: disable return intensity - def cost(self, data, psi, scan, probe, n=-1, mode=None): - intensity = self._compute_intensity(data, psi, scan, probe, n, mode) + def cost(self, data, psi, scan, probe, t=-1, n=-1, mode=None): + intensity = self._compute_intensity(data, psi, scan, probe, t, n, mode) return self.propagation.cost(data, intensity) def grad(self, data, psi, scan, probe): intensity = self._compute_intensity(data, psi, scan, probe) grad_obj = self.xp.zeros_like(psi) - for mode in np.split(probe, probe.shape[-3], axis=-3): - # TODO: Pass obj through adj() instead of making new obj inside - grad_obj += self.adj( - farplane=self.propagation.grad( - data, - self.fwd(psi=psi, scan=scan, probe=mode), - intensity, - ), - probe=mode, - scan=scan, - overwrite=True, - ) + for i in range(probe.shape[-4]): + for mode in np.split(probe[..., i:i + 1, :, :, :], + probe[..., i:i + 1, :, :, :].shape[-3], + axis=-3): + # TODO: Pass obj through adj() instead of making new obj inside + grad_obj += self.adj( + farplane=self.propagation.grad( + data, + self.fwd(psi=psi, scan=scan, probe=mode), + intensity, + ), + probe=mode, + scan=scan, + overwrite=True, + ) return grad_obj - def grad_probe(self, data, psi, scan, probe, n=-1, mode=None): - intensity = self._compute_intensity(data, psi, scan, probe, n, mode) + def grad_probe(self, data, psi, scan, probe, t=-1, n=-1, mode=None): + intensity = self._compute_intensity(data, psi, scan, probe, t, n, mode) return self.adj_probe( farplane=self.propagation.grad( data, diff --git a/src/tike/ptycho/position.py b/src/tike/ptycho/position.py index 8ae09181..6ff157a2 100644 --- a/src/tike/ptycho/position.py +++ b/src/tike/ptycho/position.py @@ -103,20 +103,23 @@ def update_positions_pd(operator, data, psi, probe, scan, dI = (data - intensity).reshape(*data.shape[:-2], np.prod(data.shape[-2:])) dI_dx, dI_dy = 0, 0 + + #update position use the central wavelength only + i = 1 for m in range(probe.shape[-3]): # step 2: the partial derivatives of wavefront respect to position farplane = operator.fwd(psi=psi, scan=scan, - probe=probe[..., m:m + 1, :, :]) + probe=probe[..., i:i + 1, m:m + 1, :, :]) dfarplane_dx = (farplane - operator.fwd( psi=psi, - probe=probe[..., m:m + 1, :, :], + probe=probe[..., i:i + 1, m:m + 1, :, :], scan=scan + operator.xp.array((0, dx), dtype='float32'), )) / dx dfarplane_dy = (farplane - operator.fwd( psi=psi, - probe=probe[..., m:m + 1, :, :], + probe=probe[..., i:i + 1, m:m + 1, :, :], scan=scan + operator.xp.array((dx, 0), dtype='float32'), )) / dx diff --git a/src/tike/ptycho/probe_MW.py b/src/tike/ptycho/probe_MW.py new file mode 100644 index 00000000..096bf94b --- /dev/null +++ b/src/tike/ptycho/probe_MW.py @@ -0,0 +1,159 @@ +##This is the python version code for initialize probes for multi-wavelength method + +import numpy as np +from scipy.interpolate import RectBivariateSpline + + +def MW_probe(probe_shape, energy, dx_dec, dis_defocus, dis_StoD, **kwargs): + # return probe sorted by the spectrum + # return scale is the wavelength dependent pixel scaling factor + """ + Summary of this function goes here + Parameters: probe_shape -> the matrix size for probe + energy -> number of wavelength for multi-wavelength recontruction + dx_dec -> pixel size on detector + dis_defocus -> defocus distance of sample plane to the focus of the FZP + dis_StoD -> sample to detector distance + kwargs -> setup: 'velo','2idd' + -> spectrum: measured spectrum (if available) [wavelength,intensity] + -> bandwidth + -> lambda0: central wavelength used to generate spectrum if no measured one was given + """ + + if 'spectrum' in kwargs: + spectrum = kwargs.get('spectrum') + spectrum = spectrum[:, ::spectrum.shape[1] // energy][:, :energy] + lambda0 = spectrum[np.argmax(spectrum[:, 1]), 0] + else: + if 'bandwidth' in kwargs: + bandwidth = kwargs.get('bandwidth') + else: + bandwidth = 0.01 + + if 'lambda0' in kwargs: + lambda0 = kwargs.get('lambda0') + else: + lambda0 = 1.24e-9 / 8.8 + spectrum = gaussian_spectrum(lambda0, bandwidth, energy) + + spectrum = spectrum[np.argsort(-spectrum[:, 1])] + + if 'setup' in kwargs: + setup = kwargs.get('setup') + else: + setup = 'default' + probe = np.zeros((energy, 1, probe_shape, probe_shape), dtype=np.complex) + + # pixel size on sample plane (central wavelength) + dx = spectrum[0, 0] * dis_StoD / probe_shape / dx_dec + # focal length for central wavelength + _, _, FL0 = fzp_calculate(spectrum[0, 0], dis_defocus, probe_shape, dx, + setup) + + for i in range(energy): + # get zone plate parameter + T, dx_fzp, _ = fzp_calculate(spectrum[i, 0], dis_defocus, probe_shape, + dx, setup) + nprobe = fresnel_propagation(T, dx_fzp, (FL0 + dis_defocus), + spectrum[i, 0]) + nprobe = nprobe / (np.sqrt(np.sum(np.abs(nprobe)**2))) + probe[i, 0, :, :] = nprobe * (np.sqrt(spectrum[i, 1])) + + return probe[np.newaxis, np.newaxis, np.newaxis] + + +def gaussian_spectrum(lambda0, bandwidth, energy): + spectrum = np.zeros((energy, 2)) + sigma = lambda0 * bandwidth / 2.355 + d_lam = sigma * 4 / (energy - 1) + spectrum[:, 0] = np.arange(-1 * np.floor(energy / 2), np.ceil( + energy / 2)) * d_lam + lambda0 + spectrum[:, 1] = np.exp(-(spectrum[:, 0] - lambda0)**2 / sigma**2) + return spectrum + + +def fzp_calculate(wavelength, dis_defocus, M, dx, setup): + """ + this function can calculate the transfer function of zone plate + return the transfer function, and the pixel sizes + """ + + FZP_para = get_setup(setup) + + FL = 2 * FZP_para['radius'] * FZP_para['outmost'] / wavelength + + # pixel size on FZP plane + dx_fzp = wavelength * (FL + dis_defocus) / M / dx + # coordinate on FZP plane + lx_fzp = -dx_fzp * np.arange(-1 * np.floor(M / 2), np.ceil(M / 2)) + + XX_FZP, YY_FZP = np.meshgrid(lx_fzp, lx_fzp) + # transmission function of FZP + T = np.exp(-1j * 2 * np.pi / wavelength * (XX_FZP**2 + YY_FZP**2) / 2 / FL) + C = np.sqrt(XX_FZP**2 + YY_FZP**2) <= FZP_para['radius'] + H = np.sqrt(XX_FZP**2 + YY_FZP**2) >= FZP_para['CS'] / 2 + + return T * C * H, dx_fzp, FL + + +def get_setup(setup): + switcher = { + 'velo': { + 'radius': 90e-6, + 'outmost': 50e-9, + 'CS': 60e-6 + }, + '2idd': { + 'radius': 80e-6, + 'outmost': 70e-9, + 'CS': 60e-6 + }, + 'default': { + 'radius': 90e-6, + 'outmost': 50e-9, + 'CS': 60e-6 + }, + } + FZP_para = switcher.get(setup) + return FZP_para + + +def fresnel_propagation(input, dxy, z, wavelength): + """ + This is the python version code for fresnel propagation + Summary of this function goes here + Parameters: dx,dy -> the pixel pitch of the object + z -> the distance of the propagation + lambda -> the wave length + X,Y -> meshgrid of coordinate + input -> input object + """ + + (M, N) = input.shape + k = 2 * np.pi / wavelength + # the coordinate grid + M_grid = np.arange(-1 * np.floor(M / 2), np.ceil(M / 2)) + N_grid = np.arange(-1 * np.floor(N / 2), np.ceil(N / 2)) + lx = M_grid * dxy + ly = N_grid * dxy + + XX, YY = np.meshgrid(lx, ly) + + # the coordinate grid on the output plane + fc = 1 / dxy + fu = wavelength * z * fc + lu = M_grid * fu / M + lv = N_grid * fu / N + Fx, Fy = np.meshgrid(lu, lv) + + if z > 0: + pf = np.exp(1j * k * z) * np.exp(1j * k * (Fx**2 + Fy**2) / 2 / z) + kern = input * np.exp(1j * k * (XX**2 + YY**2) / 2 / z) + cgh = np.fft.fft2(np.fft.fftshift(kern)) + OUT = np.fft.fftshift(cgh * np.fft.fftshift(pf)) + else: + pf = np.exp(1j * k * z) * np.exp(1j * k * (XX**2 + YY**2) / 2 / z) + cgh = np.fft.ifft2( + np.fft.fftshift(input * np.exp(1j * k * (Fx**2 + Fy**2) / 2 / z))) + OUT = np.fft.fftshift(cgh) * pf + return OUT diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index ab829673..25079904 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -117,23 +117,25 @@ def simulate( **kwargs, ) as operator: data = 0 - for mode in np.split(probe, probe.shape[-3], axis=-3): - farplane = operator.fwd( - probe=operator.asarray(mode, dtype='complex64'), - scan=operator.asarray(scan, dtype='float32'), - psi=operator.asarray(psi, dtype='complex64'), - **kwargs, - ) - data += np.square( - np.linalg.norm( - farplane.reshape(operator.ntheta, - scan.shape[-2] // operator.fly, -1, - detector_shape, detector_shape), - ord=2, - axis=2, - )) + for energy in np.split(probe, probe.shape[-4], axis=-4): + for mode in np.split(energy, probe.shape[-3], axis=-3): + farplane = operator.fwd( + probe=operator.asarray(mode, dtype='complex64'), + scan=operator.asarray(scan, dtype='float32'), + psi=operator.asarray(psi, dtype='complex64'), + **kwargs, + ) + data += np.square( + np.linalg.norm( + farplane.reshape(operator.ntheta, + scan.shape[-2] // operator.fly, -1, + detector_shape, detector_shape), + ord=2, + axis=2, + )) return operator.asnumpy(data) + def reconstruct( data, probe, scan, diff --git a/src/tike/ptycho/solvers/combined.py b/src/tike/ptycho/solvers/combined.py index 8fc8a3b5..244d5d2f 100644 --- a/src/tike/ptycho/solvers/combined.py +++ b/src/tike/ptycho/solvers/combined.py @@ -31,27 +31,28 @@ def combined( def update_probe(op, data, psi, scan, probe, num_iter=1): """Solve the probe recovery problem.""" # TODO: Cache object patche between mode updates - for m in range(probe.shape[-3]): - - def cost_function(mode): - return op.cost(data, psi, scan, probe, m, mode) - - def grad(mode): - # Use the average gradient for all probe positions - return op.xp.mean( - op.grad_probe(data, psi, scan, probe, m, mode), - axis=(1, 2), - keepdims=True, + for i in range(probe.shape[-4]): + for m in range(probe.shape[-3]): + + def cost_function(mode): + return op.cost(data, psi, scan, probe, i, m, mode) + + def grad(mode): + # Use the average gradient for all probe positions + return op.xp.mean( + op.grad_probe(data, psi, scan, probe, i, m, mode), + axis=(1, 2), + keepdims=True, + ) + + probe[..., i:i + 1, m:m + 1, :, :], cost = conjugate_gradient( + op.xp, + x=probe[..., i:i + 1, m:m + 1, :, :], + cost_function=cost_function, + grad=grad, + num_iter=num_iter, ) - probe[..., m:m + 1, :, :], cost = conjugate_gradient( - op.xp, - x=probe[..., m:m + 1, :, :], - cost_function=cost_function, - grad=grad, - num_iter=num_iter, - ) - logger.info('%10s cost is %+12.5e', 'probe', cost) return probe, cost @@ -74,4 +75,4 @@ def grad(psi): ) logger.info('%10s cost is %+12.5e', 'object', cost) - return psi, cost + return psi, cost \ No newline at end of file diff --git a/tests/data/baboon512.png b/tests/data/baboon512.png new file mode 100644 index 00000000..5646838a Binary files /dev/null and b/tests/data/baboon512.png differ diff --git a/tests/data/lena512.png b/tests/data/lena512.png new file mode 100644 index 00000000..5d6c49fb Binary files /dev/null and b/tests/data/lena512.png differ diff --git a/tests/test_MW.py b/tests/test_MW.py new file mode 100644 index 00000000..21470e33 --- /dev/null +++ b/tests/test_MW.py @@ -0,0 +1,90 @@ +# test multi_energy feature + +import tike.ptycho +from tike.ptycho.probe import add_modes_random_phase +from tike.ptycho.probe_MW import MW_probe + +import matplotlib.pyplot as plt +import h5py +import numpy as np + + +def gen_position(steppix_x, steppix_y, N_scan_x, N_scan_y, probe_shape): + # generate scan position for ptychography + ppx =np.arange(-np.floor(N_scan_x/2.0),np.ceil(N_scan_x/2.0))*steppix_x #x direction, column + ppy =np.arange(-np.floor(N_scan_y/2.0),np.ceil(N_scan_y/2.0))*steppix_y #x row, row + + [ppX, ppY] = np.meshgrid(ppx,ppy) + ppX = np.reshape(ppX, (np.product(ppX.shape),1)) + ppY = np.reshape(ppY, (np.product(ppY.shape),1)) + + ppX = ppX - np.min(ppX) + probe_shape + ppY = ppY - np.min(ppY) + probe_shape + + scan = np.hstack((ppX, ppY)) + scan = scan[np.newaxis] + return scan + +def create_dataset(testdir, energy, nmodes): + dis_StoD = 2 + dis_defocus = 500e-6 + detector_shape = 128 + detector_pixel = 75e-6 + # generate probe + probe,_ = MW_probe(detector_shape,energy,detector_pixel, + dis_defocus,dis_StoD) + + # add nmodes to probe + probe = add_modes_random_phase(probe, nmodes) + + # read simulated object + amplitude = plt.imread(os.path.join( + testdir, "data/baboon512.png")) + phase = plt.imread(os.path.join( + testdir, "data/lena512.png")) + obj = amplitude*np.exp(1j*phase* np.pi) + obj = obj[np.newaxis] + + #generate ptychography scan point + N_scan = 11 #scan point in one direction + lam = 1.24e-9/8.8 # central wavelength + # object plane pixel size + dx = lam*dis_StoD/detector_shape/detector_pixel + step = 300e-9/dx #step size in pixel + scan = gen_position(step, step, N_scan, N_scan,detector_shape) + + # test ptycho.simulate + data = tike.ptycho.simulate( + detector_shape, + probe, scan, + obj, + ) + + return data,obj,probe,scan + + + +if __name__ == "__main__": + + testdir = '/home/beams/YUDONGYAO/code/Tike/tike/tests' + energy = 10 + nmodes = 5 + + #generate simulated dataset + data,obj,probe,scan = create_dataset(testdir,energy,nmodes) + + # ptycho reconstruction + result = tike.ptycho.reconstruct( + data=data, + probe = probe, + scan = scan, + algorithm='combined', + num_iter=10, + nmode=nmodes, + energy = energy, + recover_psi=True, + recover_probe=True, + recover_positions=True, + rtol=-1, + model='poisson', + )