diff --git a/README.md b/README.md
index fe55374..642132d 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,18 @@
-# Generative modelling for mass-mapping with fast uncertainty quantification [[arXiv]](https://arxiv.org/abs/2410.24197)
+# Generative modelling for fast image reconstruction and uncertainty quantification in astronomical imaging
 
-MMGAN is a novel mass-mapping method based on the regularised conditional generative adversarial network (GAN) framework by [Bendel et al.](https://arxiv.org/abs/2210.13389). Designed to quickly generate approximate posterior samples of the convergence field from shear data, MMGAN offers a fully data-driven approach to mass-mapping. These posterior samples allow for the creation of detailed convergence map reconstructions with associated uncertainty maps, making MMGAN a cutting-edge tool for cosmological analysis.
+This repository contains two novel image reconstruction methods based on the regularised conditional generative adversarial network (GAN) framework by [Bendel et al.](https://arxiv.org/abs/2210.13389). These methods are designed to quickly generate approximate posterior samples of the image from a set of noisy data, allowing for the creation of detailed image reconstructions with associated uncertainty maps. The two methods are:
+
+**1. MMGAN**: *"Generative modelling for mass-mapping with fast uncertainty quantification"* [[arXiv]](https://arxiv.org/abs/2410.24197)
+
+MMGAN is a novel mass-mapping method designed to quickly generate approximate posterior samples of the convergence field from shear data, offering a fully data-driven approach to mass-mapping. These posterior samples allow for the creation of detailed convergence map reconstructions with associated uncertainty maps, making MMGAN a cutting-edge tool for cosmological analysis.
 
 ![MMGAN COSMOS convergence map reconstruction](/figures/MMGAN/cosmos_results.png)
+
+**2. RI-GAN**: *"Generative imaging for radio interferometry with fast uncertainty quantification"* [in prep.]
+
+RI-GAN is a novel radio interferometric imaging method that combines the regularised conditional GAN framework with model-based updates. This hybrid approach, grounded in the imaging model while remaining data-driven, allows for fast generation of approximate posterior samples from the dirty image and PSF of the observation. The result is a fast imaging method that is robust to varying visibility coverages, generalises well to unseen data, and provides informative uncertainty maps.
+
 ## Installation
 After cloning the repository, if in a computing cluster, first run:
@@ -26,22 +35,34 @@ pip install -r pypi_requirements.txt
 ### See ```docs/mass_mapping.md``` for detailed instructions on how to setup and reproduce the results from our paper on [MMGAN](https://arxiv.org/abs/2410.24197).
 
-Alternatively, we have provided a [zenodo file]https://zenodo.org/records/14226221 with the weights of our trained model, as well as a number of simulations.
+Alternatively, we have provided a [zenodo file](https://zenodo.org/records/14226221) with the weights of our trained model, as well as a number of simulations.
+
+Documentation for the RI-GAN method is currently in preparation, but we will provide a similar guide for reproducing the results from our paper on RI-GAN once it is ready.
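At inference time both methods follow the same recipe: condition the generator on the observed data, draw many posterior samples, and summarise them as a reconstruction plus an uncertainty map. A minimal sketch of that summarisation step (the `posterior_summary` helper is illustrative and not part of the repository; it assumes a trained generator `G` that maps a conditioning tensor to one image sample per call):

```python
import torch

def posterior_summary(G, y, num_samples=32):
    """Summarise posterior samples as a mean map plus a per-pixel std map."""
    with torch.no_grad():
        samples = torch.stack([G(y) for _ in range(num_samples)], dim=1)  # (B, N, C, H, W)
    mean_map = samples.mean(dim=1)  # point-estimate reconstruction
    std_map = samples.std(dim=1)    # per-pixel uncertainty map
    return mean_map, std_map
```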
 ## Questions and Concerns
-If you have any questions, or run into any issues, don't hesitate to reach out at jessica.whitney.22@ucl.ac.uk
+If you have any questions, or run into any issues, don't hesitate to reach out at jessica.whitney.22@ucl.ac.uk for the MMGAN method and academic@matthijsmars.com for the RI-GAN method.
 
 ## References
 This repository was forked from [rcGAN](https://github.com/matt-bendel/rcGAN) by [Bendel et al.](https://arxiv.org/abs/2210.13389), with significant changes and modification made by Whitney et al.
 
 ## Citation
-If you find this code helpful, please cite our paper:
-```
-@journal{2024arxiv,
-    author = {Whitney, Jessica and Liaudat, Tobías and Price, Matthew and Mars, Matthijs and McEwen, Jason},
-    title = {Generative modelling for mass-mapping with fast uncertainty quantification},
-    year = {2024},
-    journal={arXiv:2410.24197}
-}
-```
\ No newline at end of file
+If you find this code helpful, please cite our papers:
+
+- **MMGAN:**
+  ```
+  @article{2024arxiv,
+    author = {Whitney, Jessica and Liaudat, Tobías and Price, Matthew and Mars, Matthijs and McEwen, Jason},
+    title = {Generative modelling for mass-mapping with fast uncertainty quantification},
+    year = {2024},
+    journal={arXiv:2410.24197}
+  }
+  ```
+- **RI-GAN:**
+  ```
+  @article{marsGenerativeImagingRadioInterferometry,
+    author = {Mars, Matthijs and Liaudat, Tobías and Whitney, Jessica and McEwen, Jason},
+    title = {Generative imaging for radio interferometry with fast uncertainty quantification},
+    year = {},
+    journal={in prep.}
+  }
+  ```
\ No newline at end of file
diff --git a/configs/radio_meerkat_macro.yaml b/configs/radio_meerkat_macro.yaml
new file mode 100644
index 0000000..21e29f5
--- /dev/null
+++ b/configs/radio_meerkat_macro.yaml
@@ -0,0 +1,39 @@
+# Change checkpoint and data paths
+checkpoint_dir: /share/gpu0/mars/TNG_data/rcGAN/models/meerkat_macro/
+data_path: /share/gpu0/mars/TNG_data/rcGAN/meerkat_clean/
+
+# Define the experience
+experience: radio
+
+# Number of code vectors for each phase
+num_z_test: 32
+num_z_valid: 8
+num_z_train: 2
+
+# Data
+in_chans: 2 # dirty image + uv/PSF channel
+out_chans: 1
+im_size: 360 # 360x360 pixel images
+
+# Options
+alt_upsample: False # False -> convt upsampling, True -> interpolate upsampling
+norm: macro # none, micro, macro
+
+# Optimizer:
+lr: 0.001
+beta_1: 0
+beta_2: 0.99
+
+# Loss weights
+gp_weight: 10
+adv_weight: 1e-5
+
+# Training
+batch_size: 2 # per GPU
+accumulate_grad_batches: 2
+
+# Remember to increase this for full training
+num_epochs: 100
+psnr_gain_tol: 0.25
+
+num_workers: 4
diff --git a/configs/radio_meerkat_macro_gradient.yaml b/configs/radio_meerkat_macro_gradient.yaml
new file mode 100644
index 0000000..2de7443
--- /dev/null
+++ b/configs/radio_meerkat_macro_gradient.yaml
@@ -0,0 +1,40 @@
+# Change checkpoint and data paths
+checkpoint_dir: /share/gpu0/mars/TNG_data/rcGAN/models/meerkat_macro/
+data_path: /share/gpu0/mars/TNG_data/rcGAN/meerkat_clean/
+
+# Define the experience
+experience: radio
+
+# Number of code vectors for each phase
+num_z_test: 32
+num_z_valid: 8
+num_z_train: 2
+
+# Data
+in_chans: 2 # dirty image + uv/PSF channel
+out_chans: 1
+im_size: 360 # 360x360 pixel images
+
+# Options
+alt_upsample: False # False -> convt upsampling, True -> interpolate upsampling
+norm: macro # none, micro, macro
+gradient: True
+
+# Optimizer:
+lr: 0.001
+beta_1: 0
+beta_2: 0.99
+
+# Loss weights
+gp_weight: 10
+adv_weight: 1e-5
+
+# Training
+batch_size: 2 # per GPU
+accumulate_grad_batches: 2
+
+# Remember to increase this for full training
+num_epochs: 100
+psnr_gain_tol: 0.25
+
+num_workers: 4
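These configs are consumed elsewhere in the repository as attribute-style objects (e.g. `cfg.im_size`, `cfg.__dict__.get("gradient", False)`). A minimal loading sketch, assuming PyYAML and a `SimpleNamespace` wrapper (the repository's own loader may differ); note that PyYAML follows YAML 1.1, so `adv_weight: 1e-5` is parsed as a string and needs an explicit cast:

```python
import yaml
from types import SimpleNamespace

def load_config(path):
    with open(path) as f:
        raw = yaml.safe_load(f)
    raw["adv_weight"] = float(raw["adv_weight"])  # '1e-5' arrives as a string under YAML 1.1
    return SimpleNamespace(**raw)

cfg = load_config("configs/radio_meerkat_macro.yaml")
print(cfg.im_size, cfg.__dict__.get("gradient", False))  # 360 False
```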
diff --git a/data/datasets/Radio_data.py b/data/datasets/Radio_data.py
index be3b790..cdaf7ac 100644
--- a/data/datasets/Radio_data.py
+++ b/data/datasets/Radio_data.py
@@ -6,22 +6,37 @@
 class RadioDataset_Test(torch.utils.data.Dataset):
     """Loads the test data."""
 
-    def __init__(self, data_dir, transform):
+    def __init__(self, data_dir, transform, norm='micro'):
         """
         Args:
             data_dir (path): The path to the dataset.
             transform (callable): A callable object (class) that pre-processes the raw data into
                 appropriate form for it to be fed into the model.
+            norm (str): either 'none' (no normalisation), 'micro' (per sample normalisation), 'macro' (normalisation across all samples)
         """
         self.transform = transform
         # Collects the paths of all files.
         # Test/x.npy, Test/y.npy, Test/uv.npy
-        self.x = np.load(data_dir.joinpath("x.npy")).astype(np.complex128)
-        self.y = np.load(data_dir.joinpath("y.npy")).astype(np.complex128)
+        self.x = np.load(data_dir.joinpath("x.npy")).astype(np.float64)
+        self.y = np.load(data_dir.joinpath("y.npy")).astype(np.float64)
         self.uv = np.load(data_dir.joinpath("uv.npy")).real.astype(np.float64)
-        self.uv = (self.uv - self.uv.min())/(self.uv.max() - self.uv.min()) # normalize range of uv values to (0,1)
-
+
+        if norm == 'none':
+            self.transform.mean_x, self.transform.std_x = 0, 1
+            self.transform.mean_y, self.transform.std_y = 0, 1
+            self.transform.mean_uv, self.transform.std_uv = 0, 1
+        elif norm == 'micro':
+            # if micro we do the normalisation in the transform
+            pass
+        elif norm == 'macro':
+            # load means and stds from train set
+            self.transform.mean_x = np.load(data_dir.parent.joinpath("train/mean_x.npy"))
+            self.transform.std_x = np.load(data_dir.parent.joinpath("train/std_x.npy"))
+            self.transform.mean_y = np.load(data_dir.parent.joinpath("train/mean_y.npy"))
+            self.transform.std_y = np.load(data_dir.parent.joinpath("train/std_y.npy"))
+            self.transform.mean_uv = np.load(data_dir.parent.joinpath("train/mean_uv.npy"))
+            self.transform.std_uv = np.load(data_dir.parent.joinpath("train/std_uv.npy"))
 
     def __len__(self):
         """Returns the number of samples in the dataset."""
@@ -37,21 +52,38 @@ def __getitem__(self,i):
 
 class RadioDataset_Val(torch.utils.data.Dataset):
     """Loads the test data."""
 
-    def __init__(self, data_dir, transform):
+    def __init__(self, data_dir, transform, norm='micro'):
         """
         Args:
             data_dir (path): The path to the dataset.
             transform (callable): A callable object (class) that pre-processes the raw data into
                 appropriate form for it to be fed into the model.
+            norm (str): either 'none' (no normalisation), 'micro' (per sample normalisation), 'macro' (normalisation across all samples)
         """
         self.transform = transform
         # Collects the paths of all files.
# Val/x.npy, Val/y.npy, Val/uv.npy - self.x = np.load(data_dir.joinpath("x.npy")).astype(np.complex128) - self.y = np.load(data_dir.joinpath("y.npy")).astype(np.complex128) + self.x = np.load(data_dir.joinpath("x.npy")).astype(np.float64) + self.y = np.load(data_dir.joinpath("y.npy")).astype(np.float64) self.uv = np.load(data_dir.joinpath("uv.npy")).real.astype(np.float64) - self.uv = (self.uv - self.uv.min())/(self.uv.max() - self.uv.min()) # normalize range of uv values to (0,1) + + if norm == 'none': + self.transform.mean_x, self.transform.std_x = 0, 1 + self.transform.mean_y, self.transform.std_y = 0, 1 + self.transform.mean_uv, self.transform.std_uv = 0, 1 + elif norm == 'micro': + # if micro we do the normalisation in the transform + pass + elif norm == 'macro': + # load means and stds from train set + self.transform.mean_x = np.load(data_dir.parent.joinpath("train/mean_x.npy")) + self.transform.std_x = np.load(data_dir.parent.joinpath("train/std_x.npy")) + self.transform.mean_y = np.load(data_dir.parent.joinpath("train/mean_y.npy")) + self.transform.std_y = np.load(data_dir.parent.joinpath("train/std_y.npy")) + self.transform.mean_uv = np.load(data_dir.parent.joinpath("train/mean_uv.npy")) + self.transform.std_uv = np.load(data_dir.parent.joinpath("train/std_uv.npy")) + def __len__(self): """Returns the number of samples in the dataset.""" @@ -66,22 +98,43 @@ def __getitem__(self,i): class RadioDataset_Train(torch.utils.data.Dataset): """Loads the test data.""" - def __init__(self, data_dir, transform): + def __init__(self, data_dir, transform, norm='micro'): """ Args: data_dir (path): The path to the dataset. transform (callable): A callable object (class) that pre-processes the raw data into appropriate form for it to be fed into the model. + norm (str): either 'none' (no normalisation), 'micro' (per sample normalisation), 'macro' (normalisation across all samples) """ self.transform = transform # Collects the paths of all files. 
        # Train/x.npy, Train/y.npy, Train/uv.npy
-        self.x = np.load(data_dir.joinpath("x.npy")).astype(np.complex128)
-        self.y = np.load(data_dir.joinpath("y.npy")).astype(np.complex128)
+        self.x = np.load(data_dir.joinpath("x.npy")).astype(np.float64)
+        self.y = np.load(data_dir.joinpath("y.npy")).astype(np.float64)
         self.uv = np.load(data_dir.joinpath("uv.npy")).real.astype(np.float64)
-        self.uv = (self.uv - self.uv.min())/(self.uv.max() - self.uv.min()) # normalize range of uv values to (0,1)
-
+
+        if norm == 'none':
+            self.transform.mean_x, self.transform.std_x = 0, 1
+            self.transform.mean_y, self.transform.std_y = 0, 1
+            self.transform.mean_uv, self.transform.std_uv = 0, 1
+        elif norm == 'micro':
+            # if micro we do the normalisation in the transform
+            pass
+        elif norm == 'macro':
+            self.transform.mean_x, self.transform.std_x = self.x.mean(), np.mean(self.x.std(axis=(1,2)))
+            self.transform.mean_y, self.transform.std_y = self.y.mean(), np.mean(self.y.std(axis=(1,2)))
+            self.transform.mean_uv, self.transform.std_uv = self.uv.mean(), np.mean(self.uv.std(axis=(1,2)))
+
+            np.save(data_dir.joinpath("mean_x.npy"), self.transform.mean_x)
+            np.save(data_dir.joinpath("std_x.npy"), self.transform.std_x)
+            np.save(data_dir.joinpath("mean_y.npy"), self.transform.mean_y)
+            np.save(data_dir.joinpath("std_y.npy"), self.transform.std_y)
+            np.save(data_dir.joinpath("mean_uv.npy"), self.transform.mean_uv)
+            np.save(data_dir.joinpath("std_uv.npy"), self.transform.std_uv)
+
+
     def __len__(self):
         """Returns the number of samples in the dataset."""
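The three `norm` modes above share the same affine transform and differ only in where the statistics come from. A standalone numpy sketch of the idea (illustrative, not repository code):

```python
import numpy as np

x = np.random.rand(16, 360, 360)  # a stack of training images

# 'micro': per-sample statistics, computed on the fly in the transform
x_micro = (x - x.mean(axis=(1, 2), keepdims=True)) / x.std(axis=(1, 2), keepdims=True)

# 'macro': one set of statistics for the whole training set,
# saved to disk and reused by the val/test datasets
mean, std = x.mean(), np.mean(x.std(axis=(1, 2)))
x_macro = (x - mean) / std

# 'none': mean = 0 and std = 1, so the transform is the identity
```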
diff --git a/data/lightning/RadioDataModule.py b/data/lightning/RadioDataModule.py
index 4de8f5c..27b74c3 100644
--- a/data/lightning/RadioDataModule.py
+++ b/data/lightning/RadioDataModule.py
@@ -15,6 +15,8 @@ def __init__(self, args, test=False, ISNR=30):
         self.args = args
         self.test = test
         self.ISNR = ISNR
+
+        self.norm = args.__dict__.get('norm', 'micro')
 
     def __call__(self, data) -> Tuple[float, float, float, float]:
         """ Transforms the data.
@@ -36,21 +38,28 @@ def __call__(self, data) -> Tuple[float, float, float, float]:
 
         x, y, uv = data
-
+
         # Format input gt data.
-        pt_x = transforms.to_tensor(x) # Shape (H, W, 2)
-        pt_x = pt_x.permute(2, 0, 1) # Shape (2, H, W)
+        pt_x = transforms.to_tensor(x)[:, :, None] # Shape (H, W, 1)
+        pt_x = pt_x.permute(2, 0, 1) # Shape (1, H, W)
 
         # Format observation data.
-        pt_y = transforms.to_tensor(y) # Shape (H, W, 2)
-        pt_y = pt_y.permute(2, 0, 1) # Shape (2, H, W)
+        pt_y = transforms.to_tensor(y)[:, :, None] # Shape (H, W, 1)
+        pt_y = pt_y.permute(2, 0, 1) # Shape (1, H, W)
 
         # Format uv data
         pt_uv = transforms.to_tensor(uv)[:, :, None] # Shape (H, W, 1)
         pt_uv = pt_uv.permute(2, 0, 1) # Shape (1, H, W)
 
         # Normalize everything based on measurements y
-        normalized_y, mean, std = transforms.normalize_instance(pt_y)
-        normalized_x = transforms.normalize(pt_x, mean, std)
-        normalized_uv = transforms.normalize(pt_uv, mean, std)
+
+        if self.norm == 'micro':
+            normalized_y, mean, std = transforms.normalize_instance(pt_y)
+            normalized_x = transforms.normalize(pt_x, mean, std) # scale based on input
+            normalized_uv, _, _ = transforms.normalize_instance(pt_uv) # scale on itself
+        else:
+            normalized_y = transforms.normalize(pt_y, self.mean_y, self.std_y) # scale globally
+            normalized_x = transforms.normalize(pt_x, self.mean_x, self.std_x) # scale globally
+            normalized_uv = transforms.normalize(pt_uv, self.mean_uv, self.std_uv) # scale globally
+            mean, std = self.mean_x, self.std_x
 
         # Use normalized stack of y + uv
         normalized_y = torch.cat([normalized_y, normalized_uv], dim=0)
@@ -72,6 +81,7 @@ def __init__(self, args):
         super().__init__()
         self.prepare_data_per_node = True
         self.args = args
+        self.norm = args.__dict__.get('norm', 'micro')
 
     def prepare_data(self):
         pass
@@ -81,17 +91,20 @@ def setup(self, stage: Optional[str] = None):
 
         train_data = RadioDataset_Train(
             data_dir=pathlib.Path(self.args.data_path) / 'train',
-            transform=RadioDataTransform(self.args, test=False)
+            transform=RadioDataTransform(self.args, test=False),
+            norm=self.norm
         )
 
         dev_data = RadioDataset_Val(
             data_dir=pathlib.Path(self.args.data_path) / 'val',
-            transform=RadioDataTransform(self.args, test=True)
+            transform=RadioDataTransform(self.args, test=True),
+            norm=self.norm
         )
 
         test_data = RadioDataset_Test(
             data_dir=pathlib.Path(self.args.data_path) / 'test',
-            transform=RadioDataTransform(self.args, test=True)
+            transform=RadioDataTransform(self.args, test=True),
+            norm=self.norm
         )
 
         self.train, self.validate, self.test = train_data, dev_data, test_data
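With the datasets now serving real-valued arrays, the conditioning tensor built by this transform is the normalised dirty image stacked with the normalised uv/PSF channel; the GAN's forward pass later appends two noise channels to this. A quick shape walk-through under those assumptions:

```python
import torch

H = W = 360
pt_y = torch.randn(1, H, W)   # dirty image, one channel
pt_uv = torch.randn(1, H, W)  # uv / PSF channel

cond = torch.cat([pt_y, pt_uv], dim=0)    # (2, H, W) -> matches in_chans: 2
noise = torch.randn(2, H, W)              # added inside the GAN's forward()
gen_in = torch.cat([cond, noise], dim=0)  # (4, H, W) -> generator input
```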
+ """ + # Unlike numpy, tensorflow's return order is (s, u, v) + u, s, v = torch.linalg.svd(mat) + # sqrt is unstable around 0, just use 0 in such case + si = s + si[torch.where(si >= eps)] = torch.sqrt(si[torch.where(si >= eps)]) + + # Note that the v returned by Tensorflow is v = V + # (when referencing the equation A = U S V^T) + # This is unlike Numpy which returns v = V^T + return torch.matmul(torch.matmul(u, torch.diag(si)), v) + + +def trace_sqrt_product_torch(sigma, sigma_v): + """Find the trace of the positive sqrt of product of covariance matrices. + '_symmetric_matrix_square_root' only works for symmetric matrices, so we + cannot just take _symmetric_matrix_square_root(sigma * sigma_v). + ('sigma' and 'sigma_v' are symmetric, but their product is not necessarily). + Let sigma = A A so A = sqrt(sigma), and sigma_v = B B. + We want to find trace(sqrt(sigma sigma_v)) = trace(sqrt(A A B B)) + Note the following properties: + (i) forall M1, M2: eigenvalues(M1 M2) = eigenvalues(M2 M1) + => eigenvalues(A A B B) = eigenvalues (A B B A) + (ii) if M1 = sqrt(M2), then eigenvalues(M1) = sqrt(eigenvalues(M2)) + => eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A)) + (iii) forall M: trace(M) = sum(eigenvalues(M)) + => trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v))) + = sum(sqrt(eigenvalues(A B B A))) + = sum(eigenvalues(sqrt(A B B A))) + = trace(sqrt(A B B A)) + = trace(sqrt(A sigma_v A)) + A = sqrt(sigma). Both sigma and A sigma_v A are symmetric, so we **can** + use the _symmetric_matrix_square_root function to find the roots of these + matrices. + Args: + sigma: a square, symmetric, real, positive semi-definite covariance matrix + sigma_v: same as sigma + Returns: + The trace of the positive square root of sigma*sigma_v + """ + + # Note sqrt_sigma is called "A" in the proof above + sqrt_sigma = symmetric_matrix_square_root_torch(sigma) + + # This is sqrt(A sigma_v A) above + sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) + + return torch.trace(symmetric_matrix_square_root_torch(sqrt_a_sigmav_a)) + + +# **Estimators** +# +def sample_covariance_torch(a, b): + ''' + Sample covariance estimating + a = [N,m] + b = [N,m] + ''' + assert (a.shape[0] == b.shape[0]) + assert (a.shape[1] == b.shape[1]) + m = a.shape[1] + N = a.shape[0] + return torch.matmul(torch.transpose(a, 0, 1), b) / N + + +class CFIDMetric: + """Helper function for calculating CFID metric. + + Note: This code is adapted from Facebook's FJD implementation in order to compute + CFID in a streamlined fashion. + + Args: + gan: Model that takes in a conditioning tensor and yields image samples. + reference_loader: DataLoader that yields (images, conditioning) pairs + to be used as the reference distribution. + condition_loader: Dataloader that yields (image, conditioning) pairs. + Images are ignored, and conditions are fed to the GAN. + image_embedding: Function that takes in 4D [B, 3, H, W] image tensor + and yields 2D [B, D] embedding vectors. + condition_embedding: Function that takes in conditioning from + condition_loader and yields 2D [B, D] embedding vectors. + reference_stats_path: File path to save precomputed statistics of + reference distribution. Default: current directory. + save_reference_stats: Boolean indicating whether statistics of + reference distribution should be saved. Default: False. + samples_per_condition: Integer indicating the number of samples to + generate for each condition from the condition_loader. Default: 1. 
+ cuda: Boolean indicating whether to use GPU accelerated FJD or not. + Default: False. + eps: Float value which is added to diagonals of covariance matrices + to improve computational stability. Default: 1e-6. + """ + + def __init__(self, + gan, + loader, + image_embedding, + condition_embedding, + cuda=False, + args=None, + eps=1e-6, + ref_loader=False, + num_samps=1): + + self.gan = gan + self.args = args + self.loader = loader + self.image_embedding = image_embedding + self.condition_embedding = condition_embedding + self.cuda = cuda + self.eps = eps + self.gen_embeds, self.cond_embeds, self.true_embeds = None, None, None + self.num_samps = num_samps + self.ref_loader = ref_loader + self.transforms = torch.nn.Sequential( + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ) + + def _get_embed_im(self, multi_coil_inp, mean, std): + embed_ims = torch.zeros(size=(multi_coil_inp.size(0), 3, self.args.im_size, self.args.im_size)).cuda() + for i in range(multi_coil_inp.size(0)): + reformatted = torch.zeros(size=(1, self.args.im_size, self.args.im_size, 1)).cuda() + reformatted[:, :, :, 0] = multi_coil_inp[i, 0, :, :] + + unnormal_im = reformatted * std[i] + mean[i] + im = unnormal_im[0, :, :, 0] + im = (im - torch.min(im)) / (torch.max(im) - torch.min(im)) + + embed_ims[i, 0, :, :] = im + embed_ims[i, 1, :, :] = im + embed_ims[i, 2, :, :] = im + + return embed_ims + + def _get_generated_distribution(self): + image_embed = [] + cond_embed = [] + true_embed = [] + #cfids = [] + #count = 0 + for i, data in tqdm(enumerate(self.loader), + desc='Computing generated distribution', + total=len(self.loader)): + condition, gt, mean, std = data + condition = condition.cuda() # condition = y + gt = gt.cuda() + mean = mean.cuda() + std = std.cuda() + + with torch.no_grad(): + + for l in range(self.num_samps): + recon = self.gan(condition) + + image = self._get_embed_im(recon, mean, std) + condition_im = self._get_embed_im(condition, mean, std) + true_im = self._get_embed_im(gt, mean, std) + # WARNING -> transform() + img_e = self.image_embedding(self.transforms(image)) + cond_e = self.condition_embedding(self.transforms(condition_im)) + true_e = self.image_embedding(self.transforms(true_im)) + + if self.cuda: + true_embed.append(true_e) + image_embed.append(img_e) + cond_embed.append(cond_e) + else: + true_embed.append(true_e.cpu().numpy()) + image_embed.append(img_e.cpu().numpy()) + cond_embed.append(cond_e.cpu().numpy()) + + if self.ref_loader: + with torch.no_grad(): + for i, data in tqdm(enumerate(self.ref_loader), + desc='Computing generated distribution', + total=len(self.ref_loader)): + condition, gt, mean, std = data + condition = condition.cuda() + gt = gt.cuda() + mean = mean.cuda() + std = std.cuda() + + with torch.no_grad(): + + for l in range(self.num_samps): + recon = self.gan(condition) + + image = self._get_embed_im(recon, mean, std) + condition_im = self._get_embed_im(condition, mean, std) + true_im = self._get_embed_im(gt, mean, std) + + img_e = self.image_embedding(self.transforms(image)) + cond_e = self.condition_embedding(self.transforms(condition_im)) + true_e = self.image_embedding(self.transforms(true_im)) + + if self.cuda: + true_embed.append(true_e) + image_embed.append(img_e) + cond_embed.append(cond_e) + else: + true_embed.append(true_e.cpu().numpy()) + image_embed.append(img_e.cpu().numpy()) + cond_embed.append(cond_e.cpu().numpy()) + + if self.cuda: + true_embed = torch.cat(true_embed, dim=0) + image_embed = torch.cat(image_embed, dim=0) + cond_embed = 
torch.cat(cond_embed, dim=0)
+        else:
+            # convert to tensors so the dtype cast below also works on the CPU path
+            true_embed = torch.from_numpy(np.concatenate(true_embed, axis=0))
+            image_embed = torch.from_numpy(np.concatenate(image_embed, axis=0))
+            cond_embed = torch.from_numpy(np.concatenate(cond_embed, axis=0))
+
+        return image_embed.to(dtype=torch.float64), cond_embed.to(dtype=torch.float64), true_embed.to(
+            dtype=torch.float64)
+
+    def get_cfid_torch_pinv(self, resample=True, y_predict=None, x_true=None, y_true=None):
+        if y_true is None:
+            y_predict, x_true, y_true = self._get_generated_distribution()
+
+        # mean estimations
+        y_true = y_true.to(x_true.device)
+        m_y_predict = torch.mean(y_predict, dim=0)
+        m_x_true = torch.mean(x_true, dim=0)
+        m_y_true = torch.mean(y_true, dim=0)
+
+        no_m_y_true = y_true - m_y_true
+        no_m_y_pred = y_predict - m_y_predict
+        no_m_x_true = x_true - m_x_true
+
+        c_y_predict_x_true = torch.matmul(no_m_y_pred.t(), no_m_x_true) / y_predict.shape[0]
+        c_y_predict_y_predict = torch.matmul(no_m_y_pred.t(), no_m_y_pred) / y_predict.shape[0]
+        c_x_true_y_predict = torch.matmul(no_m_x_true.t(), no_m_y_pred) / y_predict.shape[0]
+
+        c_y_true_x_true = torch.matmul(no_m_y_true.t(), no_m_x_true) / y_predict.shape[0]
+        c_x_true_y_true = torch.matmul(no_m_x_true.t(), no_m_y_true) / y_predict.shape[0]
+        c_y_true_y_true = torch.matmul(no_m_y_true.t(), no_m_y_true) / y_predict.shape[0]
+
+        inv_c_x_true_x_true = torch.linalg.pinv(torch.matmul(no_m_x_true.t(), no_m_x_true) / y_predict.shape[0])
+
+        c_y_true_given_x_true = c_y_true_y_true - torch.matmul(c_y_true_x_true,
+                                                               torch.matmul(inv_c_x_true_x_true, c_x_true_y_true))
+        c_y_predict_given_x_true = c_y_predict_y_predict - torch.matmul(c_y_predict_x_true,
+                                                                        torch.matmul(inv_c_x_true_x_true,
+                                                                                     c_x_true_y_predict))
+        c_y_true_x_true_minus_c_y_predict_x_true = c_y_true_x_true - c_y_predict_x_true
+        c_x_true_y_true_minus_c_x_true_y_predict = c_x_true_y_true - c_x_true_y_predict
+
+        # Distance between Gaussians
+        m_dist = torch.einsum('...k,...k->...', m_y_true - m_y_predict, m_y_true - m_y_predict)
+        c_dist1 = torch.trace(
+            torch.matmul(torch.matmul(c_y_true_x_true_minus_c_y_predict_x_true, inv_c_x_true_x_true),
+                         c_x_true_y_true_minus_c_x_true_y_predict))
+        c_dist_2_1 = torch.trace(c_y_true_given_x_true + c_y_predict_given_x_true)
+        c_dist_2_2 = - 2 * trace_sqrt_product_torch(
+            c_y_predict_given_x_true, c_y_true_given_x_true)
+
+        c_dist2 = c_dist_2_1 + c_dist_2_2
+
+        cfid = m_dist + c_dist1 + c_dist2
+
+        c_dist = c_dist1 + c_dist2
+
+        return cfid.cpu().numpy(), m_dist.cpu().numpy(), c_dist.cpu().numpy()
diff --git a/find_batch_size.py b/find_batch_size.py
index 9e56cab..3e48e89 100644
--- a/find_batch_size.py
+++ b/find_batch_size.py
@@ -18,6 +18,8 @@
 from data.lightning.MassMappingDataModule import MMDataModule
 from data.lightning.RadioDataModule import RadioDataModule
 from models.lightning.mmGAN import mmGAN
+from models.lightning.riGAN import riGAN
+from models.lightning.GriGAN import GriGAN
 from torch.utils.data import DataLoader
 
 def load_object(dct):
@@ -49,7 +51,11 @@ def load_object(dct):
         model = mmGAN(cfg, args.exp_name, args.num_gpus)
     elif cfg.experience == 'radio':
         DM = RadioDataModule
-        model = mmGAN(cfg, args.exp_name, args.num_gpus)
+        if cfg.__dict__.get("gradient", False):
+            model = GriGAN(cfg, args.exp_name, args.num_gpus)
+        else:
+            model = riGAN(cfg, args.exp_name, args.num_gpus)
+
     else:
         print("No valid experience selected in config file. Options are 'mri', 'mass_mapping', 'radio'.")
         exit()
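A rough usage sketch for the CFID metric above; `embed_net`, `val_loader`, `model`, and `cfg` are placeholders (e.g. an Inception/VGG-style feature extractor mapping `[B, 3, H, W]` images to `[B, D]` embeddings, a loader yielding `(condition, gt, mean, std)` batches, a trained generator, and a loaded config):

```python
metric = CFIDMetric(
    gan=model,                 # trained generator, called as gan(condition)
    loader=val_loader,         # yields (condition, gt, mean, std) batches
    image_embedding=embed_net,
    condition_embedding=embed_net,
    cuda=True,
    args=cfg,
)
cfid, m_dist, c_dist = metric.get_cfid_torch_pinv()
```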
diff --git a/models/archs/radio/__init__.py b/models/archs/radio/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/archs/radio/discriminator.py b/models/archs/radio/discriminator.py
new file mode 100644
index 0000000..3e405e3
--- /dev/null
+++ b/models/archs/radio/discriminator.py
@@ -0,0 +1,124 @@
+# Radio discriminator
+
+import torch
+from torch import nn
+
+
+class ResidualBlock(nn.Module):
+    """
+    A residual block consisting of two 3x3 convolution layers with LeakyReLU
+    activations and a 1x1 convolution on the skip connection.
+    """
+
+    def __init__(self, in_chans, out_chans, batch_norm=True):
+        """
+        Args:
+            in_chans (int): Number of channels in the input.
+            out_chans (int): Number of channels in the output.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+        self.batch_norm = batch_norm
+
+        # the block requires in_chans == out_chans; force them to match
+        if self.in_chans != self.out_chans:
+            self.out_chans = self.in_chans
+
+        self.conv_1_x_1 = nn.Conv2d(self.in_chans, self.out_chans, kernel_size=(1, 1))
+        self.layers = nn.Sequential(
+            nn.LeakyReLU(negative_slope=0.2),
+            nn.Conv2d(self.in_chans, self.out_chans, kernel_size=(3, 3), padding=1),
+            nn.LeakyReLU(negative_slope=0.2),
+            nn.Conv2d(self.in_chans, self.out_chans, kernel_size=(3, 3), padding=1),
+        )
+
+    def forward(self, input):
+        """
+        Args:
+            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
+
+        Returns:
+            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
+        """
+        output = input
+
+        return self.layers(output) + self.conv_1_x_1(output)
+
+
+class FullDownBlock(nn.Module):
+    def __init__(self, in_chans, out_chans):
+        """
+        Args:
+            in_chans (int): Number of channels in the input.
+            out_chans (int): Number of channels in the output.
+        """
+        super().__init__()
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+
+        self.downsample = nn.Sequential(
+            nn.AvgPool2d(kernel_size=(2, 2), stride=2),
+            nn.Conv2d(self.in_chans, self.out_chans, kernel_size=(3, 3), padding=1),
+            nn.InstanceNorm2d(self.out_chans),
+            nn.LeakyReLU(negative_slope=0.2),
+        )
+
+    def forward(self, input):
+        """
+        Args:
+            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
+
+        Returns:
+            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
+        """
+
+        return self.downsample(input)
+
+    def __repr__(self):
+        return f'AvgPool(in_chans={self.in_chans}, out_chans={self.out_chans})\nResBlock(in_chans={self.out_chans}, out_chans={self.out_chans})'
+
+
+class DiscriminatorModel(nn.Module):
+    def __init__(self, in_chans, out_chans, input_im_size):
+        """
+        Args:
+            in_chans (int): Number of channels in the input to the discriminator.
+            out_chans (int): Number of channels in the output of the discriminator.
+            input_im_size (int): Spatial size of the (square) input images.
+        """
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.out_chans = out_chans
+        self.input_im_size = input_im_size
+
+        self.initial_layers = nn.Sequential(
+            nn.Conv2d(self.in_chans, 32, kernel_size=(3, 3), padding=1), # full input resolution
+            nn.LeakyReLU()
+        )
+
+        # This should be refactored to adapt to input and output number of features and the resolution dimensions
+        self.encoder_layers = nn.ModuleList()
+        self.encoder_layers += [FullDownBlock(32, 64)]
+        self.encoder_layers += [FullDownBlock(64, 128)]
+        self.encoder_layers += [FullDownBlock(128, 256)]
+        self.encoder_layers += [FullDownBlock(256, 512)]
+        self.encoder_layers += [FullDownBlock(512, 512)]
+        self.encoder_layers += [FullDownBlock(512, 512)]
+
+        downsampled_imsize = self.input_im_size
+        for i in range(6):
+            downsampled_imsize = downsampled_imsize // 2 # half dimension (rounded down) for every FullDownBlock
+
+        self.dense = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(512 * downsampled_imsize**2, 1),
+        )
+
+    def forward(self, input, y):
+        output = torch.cat([input, y], dim=1)
+        output = self.initial_layers(output)
+        # Apply down-sampling layers
+        for layer in self.encoder_layers:
+            output = layer(output)
+        return self.dense(output)
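The `Linear` layer's input size in `DiscriminatorModel` follows from six floor-halvings of the input resolution, one per `FullDownBlock`. A quick check with `im_size: 360` from the configs above:

```python
im = 360
for _ in range(6):    # one halving per FullDownBlock
    im //= 2          # 360 -> 180 -> 90 -> 45 -> 22 -> 11 -> 5
print(512 * im * im)  # 12800 input features to the final Linear layer
```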
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.input_im_size = input_im_size + + self.initial_layers = nn.Sequential( + nn.Conv2d(self.in_chans, 32, kernel_size=(3, 3), padding=1), # 384x384 + nn.LeakyReLU() + ) + + # This should be refactored to adapt to input and output number of features and the resolution dimensions + self.encoder_layers = nn.ModuleList() + self.encoder_layers += [FullDownBlock(32, 64)] + self.encoder_layers += [FullDownBlock(64, 128)] + self.encoder_layers += [FullDownBlock(128, 256)] + self.encoder_layers += [FullDownBlock(256, 512)] + self.encoder_layers += [FullDownBlock(512, 512)] + self.encoder_layers += [FullDownBlock(512, 512)] + + downsampled_imsize = self.input_im_size + for i in range(6): + downsampled_imsize = downsampled_imsize // 2 # half dimension (rounded down) for every FullDownBlock + + self.dense = nn.Sequential( + nn.Flatten(), + nn.Linear(512 * downsampled_imsize**2 , 1), + ) + + def forward(self, input, y): + output = torch.cat([input, y], dim=1) + output = self.initial_layers(output) + # Apply down-sampling layers + for layer in self.encoder_layers: + output = layer(output) + return self.dense(output) diff --git a/models/archs/radio/generator.py b/models/archs/radio/generator.py new file mode 100644 index 0000000..58f97f4 --- /dev/null +++ b/models/archs/radio/generator.py @@ -0,0 +1,235 @@ +#Mass Map Generator + +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +import torchvision.transforms as transforms + +class ResidualBlock(nn.Module): + def __init__(self, in_features): + super(ResidualBlock, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(in_features), + nn.PReLU(), + nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(in_features), + nn.PReLU() + ) + self.conv_1x1 = nn.Conv2d(in_features, in_features, kernel_size=1) + + def forward(self, x): + return self.conv_1x1(x) + self.conv_block(x) + + +class ConvDownBlock(nn.Module): + def __init__(self, in_chans, out_chans, batch_norm=True): + """ + Args: + in_chans (int): Number of channels in the input. + out_chans (int): Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.batch_norm = batch_norm + + self.conv_1 = nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1) + self.res = ResidualBlock(out_chans) + self.conv_3 = nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, stride=2) + self.bn = nn.BatchNorm2d(out_chans) + self.activation = nn.PReLU() + + def forward(self, input): + """ + Args: + input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] + + Returns: + (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] + """ + + if self.batch_norm: + out = self.activation(self.bn(self.conv_1(input))) + skip_out = self.res(out) + out = self.conv_3(skip_out) + else: + out = self.activation(self.conv_1(input)) + skip_out = self.res(out) + out = self.conv_3(skip_out) + + return out, skip_out + + +class ConvUpBlock(nn.Module): + def __init__(self, in_chans, out_chans): + """ + Args: + in_chans (int): Number of channels in the input. + out_chans (int): Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.conv_1 = nn.ConvTranspose2d(in_chans // 2, in_chans // 2, kernel_size=3, padding=1, stride=2) + self.bn = nn.BatchNorm2d(in_chans // 2) + self.activation = nn.PReLU() + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1), + nn.BatchNorm2d(out_chans), + nn.PReLU(), + ResidualBlock(out_chans), + ) + + def forward(self, input, skip_input): + """ + Args: + input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] + + Returns: + (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] + """ + + residual_skip = skip_input + upsampled = self.activation(self.bn(self.conv_1(input, output_size=residual_skip.size()))) + concat_tensor = torch.cat([residual_skip, upsampled], dim=1) + + return self.layers(concat_tensor) + + +class ConvUpBlock_alt_upsample(nn.Module): + def __init__(self, in_chans, out_chans): + """ + Args: + in_chans (int): Number of channels in the input. + out_chans (int): Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.conv_1 = nn.Conv2d(in_chans // 2, in_chans//2, kernel_size=3, padding=1) + + self.bn = nn.BatchNorm2d(in_chans // 2) + self.activation = nn.PReLU() + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1), + nn.BatchNorm2d(out_chans), + nn.PReLU(), + ResidualBlock(out_chans), + ) + + def forward(self, input, skip_input): + """ + Args: + input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] + + Returns: + (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] + """ + + + residual_skip = skip_input + resized = torch.nn.functional.interpolate(input, size=skip_input.size()[2:], mode='nearest') + upsampled = self.activation(self.bn(self.conv_1(resized))) + concat_tensor = torch.cat([residual_skip, upsampled], dim=1) + + return self.layers(concat_tensor) + + +class UNetModel(nn.Module): + def __init__(self, in_chans, out_chans, alt_upsample=False, chans=128, num_pool_layers=4): + """ + Args: + in_chans (int): Number of channels in the input to the U-Net model. + out_chans (int): Number of channels in the output to the U-Net model. + chans (int): Number of output channels of the first convolution layer. + num_pool_layers (int): Number of down-sampling and up-sampling layers. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.chans = chans + self.num_pool_layers = num_pool_layers + + self.down_sample_layers = nn.ModuleList([ConvDownBlock(in_chans, self.chans, batch_norm=False)]) + for i in range(self.num_pool_layers - 1): + if i < 3: + self.down_sample_layers += [ConvDownBlock(self.chans, self.chans * 2)] + self.chans *= 2 + else: + self.down_sample_layers += [ConvDownBlock(self.chans, self.chans)] + + self.res_layer_1 = nn.Sequential( + nn.Conv2d(self.chans, self.chans, kernel_size=3, padding=1), + nn.BatchNorm2d(self.chans), + nn.PReLU(), + ResidualBlock(self.chans), + ResidualBlock(self.chans), + ResidualBlock(self.chans), + ResidualBlock(self.chans), + ResidualBlock(self.chans), + ) + + self.conv = nn.Sequential( + nn.Conv2d(self.chans, self.chans, kernel_size=3, padding=1), + nn.BatchNorm2d(self.chans), + nn.PReLU(), + ) + + self.up_sample_layers = nn.ModuleList() + + + for i in range(self.num_pool_layers - 1): + if not alt_upsample: + self.up_sample_layers += [ConvUpBlock(self.chans * 2, self.chans // 2)] + else: + self.up_sample_layers += [ConvUpBlock_alt_upsample(self.chans * 2, self.chans // 2)] + + self.chans //= 2 + + if not alt_upsample: + self.up_sample_layers += [ConvUpBlock(self.chans * 2, self.chans)] + else: + self.up_sample_layers += [ConvUpBlock_alt_upsample(self.chans * 2, self.chans)] + + self.conv2 = nn.Sequential( + nn.Conv2d(self.chans, self.chans // 2, kernel_size=1), + nn.Conv2d(self.chans // 2, out_chans, kernel_size=1), + ) + + def forward(self, input): + """ + Args: + input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] + + Returns: + (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] + """ + output = input + stack = [] + # Apply down-sampling layers + for layer in self.down_sample_layers: + output, skip_out = layer(output) + stack.append(skip_out) + + output = self.conv(output) + + # Apply up-sampling layers + for layer in self.up_sample_layers: + output = layer(output, stack.pop()) + + return self.conv2(output) diff --git a/models/archs/radio/gradient_generator.py b/models/archs/radio/gradient_generator.py new file mode 100644 index 0000000..d4e0cff --- /dev/null +++ b/models/archs/radio/gradient_generator.py @@ -0,0 +1,322 @@ +#Mass Map Generator + +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch import nn +import torchvision.transforms as transforms + +class GradBlock(nn.Module): + def __init__(self, in_chans, out_chans, batch_norm=False): + """ + Gradient block that calculates the gradient of the L2 norm of the input image with respect to the dirty image and appends this to the input. + + Args: + in_chans (int): Number of channels in the input. + out_chans (int): Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.batch_norm = batch_norm + + self.conv_1 = nn.Conv2d(3, out_chans, kernel_size=3, padding=1) + + self.fft = lambda x: torch.fft.fft2(torch.fft.fftshift(x, dim=(-2, -1))) + self.ifft = lambda x: torch.fft.ifftshift(torch.fft.ifft2(x), dim=(-2,-1)) + + + @staticmethod + def fft(x): + x = x.to(torch.complex64) + torch.fft.fft2(x, norm='ortho', out=x) + return x + + def forward(self, x, x_i, dirty_i, psf_i): + ft_psf = self.fft(psf_i) + m = torch.real(self.ifft(self.fft(x_i) * ft_psf)) + + x_i = torch.cat((x_i, m, dirty_i), dim=1) + x_i = self.conv_1(x_i) + + return torch.cat((x, x_i), dim=1) # concat info to original channels + + +class ResidualBlock(nn.Module): + def __init__(self, in_features): + super(ResidualBlock, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(in_features), + nn.PReLU(), + nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(in_features), + nn.PReLU() + ) + self.conv_1x1 = nn.Conv2d(in_features, in_features, kernel_size=1) + + def forward(self, x): + return self.conv_1x1(x) + self.conv_block(x) + + +class ConvDownBlock(nn.Module): + def __init__(self, in_chans, out_chans, batch_norm=True): + """ + Args: + in_chans (int): Number of channels in the input. + out_chans (int): Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.batch_norm = batch_norm + + self.conv_1 = nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1) + self.res = ResidualBlock(out_chans) + self.conv_3 = nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, stride=2) + self.bn = nn.BatchNorm2d(out_chans) + self.activation = nn.PReLU() + + def forward(self, input): + """ + Args: + input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] + + Returns: + (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] + """ + + if self.batch_norm: + out = self.activation(self.bn(self.conv_1(input))) + skip_out = self.res(out) + out = self.conv_3(skip_out) + else: + out = self.activation(self.conv_1(input)) + skip_out = self.res(out) + out = self.conv_3(skip_out) + + return out, skip_out + + +class ConvUpBlock(nn.Module): + def __init__(self, in_chans, out_chans, skip_chans): + """ + Args: + in_chans (int): Number of channels in the input. + out_chans (int): Number of channels in the output. 
+ """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + print(in_chans, skip_chans, out_chans) + self.conv_1 = nn.ConvTranspose2d(in_chans, in_chans // 2 , kernel_size=3, padding=1, stride=2) + self.bn = nn.BatchNorm2d(in_chans // 2) + self.activation = nn.PReLU() + + self.layers = nn.Sequential( + nn.Conv2d(in_chans//2 + skip_chans*2 , out_chans, kernel_size=3, padding=1), # *2 comes from skip channels + gradient channels + nn.BatchNorm2d(out_chans), + nn.PReLU(), + ResidualBlock(out_chans), + ) + + def forward(self, input, skip_input): + """ + Args: + input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] + + Returns: + (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] + """ + + residual_skip = skip_input + upsampled = self.activation(self.bn(self.conv_1(input, output_size=residual_skip.size()))) + concat_tensor = torch.cat([residual_skip, upsampled], dim=1) + + return self.layers(concat_tensor) + + def forward_part_1(self, input, skip_input): + """ + Args: + input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] + + Returns: + (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] + """ + + residual_skip = skip_input + upsampled = self.activation(self.bn(self.conv_1(input, output_size=residual_skip.size()))) + + return upsampled + + def forward_part_2(self, residual_skip, upsampled): + """ + Args: + input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] + + Returns: + (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] + """ + + concat_tensor = torch.cat([residual_skip, upsampled], dim=1) + + return self.layers(concat_tensor) + + +class ConvUpBlock_alt_upsample(nn.Module): + def __init__(self, in_chans, out_chans): + """ + Args: + in_chans (int): Number of channels in the input. + out_chans (int): Number of channels in the output. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.conv_1 = nn.Conv2d(in_chans // 2, in_chans//2, kernel_size=3, padding=1) + + self.bn = nn.BatchNorm2d(in_chans // 2) + self.activation = nn.PReLU() + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1), + nn.BatchNorm2d(out_chans), + nn.PReLU(), + ResidualBlock(out_chans), + ) + + def forward(self, input, skip_input): + """ + Args: + input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] + + Returns: + (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] + """ + + + residual_skip = skip_input + resized = torch.nn.functional.interpolate(input, size=skip_input.size()[2:], mode='nearest') + upsampled = self.activation(self.bn(self.conv_1(resized))) + concat_tensor = torch.cat([residual_skip, upsampled], dim=1) + + return self.layers(concat_tensor) + + +class UNetModel(nn.Module): + def __init__(self, in_chans, out_chans, alt_upsample=False, chans=128, num_pool_layers=4): + """ + Args: + in_chans (int): Number of channels in the input to the U-Net model. + out_chans (int): Number of channels in the output to the U-Net model. + chans (int): Number of output channels of the first convolution layer. + num_pool_layers (int): Number of down-sampling and up-sampling layers. 
+ """ + super().__init__() + + + self.in_chans = in_chans + self.out_chans = out_chans + self.chans = chans + self.num_pool_layers = num_pool_layers + + # create down_sampled dirty image and psf + self.ds_layer = nn.AvgPool2d((2,2), ceil_mode=True) + self.us_layer = nn.AvgPool2d((2,2), ceil_mode=False) + + + self.grad_layers = nn.ModuleList([GradBlock(in_chans, self.chans, batch_norm=False)]) + self.down_sample_layers = nn.ModuleList([ConvDownBlock(self.chans + in_chans, self.chans*2, batch_norm=False)]) + self.chans *= 2 + for i in range(self.num_pool_layers): + + self.grad_layers.append(GradBlock(self.chans, self.chans)) + self.chans *= 2 + if i < self.num_pool_layers - 1: + self.down_sample_layers += [ConvDownBlock(self.chans, self.chans)] + + self.res_layer_1 = nn.Sequential( + nn.Conv2d(self.chans, self.chans, kernel_size=3, padding=1), + nn.BatchNorm2d(self.chans), + nn.PReLU(), + ResidualBlock(self.chans), + ResidualBlock(self.chans), + ResidualBlock(self.chans), + ResidualBlock(self.chans), + ResidualBlock(self.chans), + ) + + self.conv = nn.Sequential( + nn.Conv2d(self.chans, self.chans, kernel_size=3, padding=1), + nn.BatchNorm2d(self.chans), + nn.PReLU(), + ) + + # from the lowest level gradient + self.up_sample_layers = nn.ModuleList() + + + for i in range(self.num_pool_layers - 1): + if not alt_upsample: + self.up_sample_layers += [ConvUpBlock(self.chans, self.chans // 2, self.chans // 2)] + else: + self.up_sample_layers += [ConvUpBlock_alt_upsample(self.chans * 2, self.chans // 2)] + + self.chans //= 2 + + if not alt_upsample: + self.up_sample_layers += [ConvUpBlock(self.chans, self.chans // 2, self.chans // 2)] + else: + self.up_sample_layers += [ConvUpBlock_alt_upsample(self.chans * 2, self.chans)] + self.chans //= 2 + + self.conv2 = nn.Sequential( + nn.Conv2d(self.chans, self.chans // 2, kernel_size=1), + nn.Conv2d(self.chans // 2, out_chans, kernel_size=1), + ) + + def forward(self, input): + """ + Args: + input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] + + Returns: + (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] + """ + + output = input + stack = [] + + downsampled_input = [output] + up_sampled_input = [output] + for i in range(self.num_pool_layers): + downsampled_input.append(self.ds_layer(downsampled_input[-1])) + up_sampled_input.append(self.us_layer(up_sampled_input[-1])) + + for i, layer in enumerate(self.down_sample_layers): + output = self.grad_layers[i](output, output[:,:1], downsampled_input[i][:,:1], downsampled_input[i][:,1:2]) + output, skip_out = layer(output) + + stack.append(skip_out) + + i+=1 + output = self.grad_layers[i](output,output[:,:1], downsampled_input[i][:,:1], downsampled_input[i][:,1:2]) + output = self.conv(output) + + # Apply up-sampling layers + for i, layer in enumerate(self.up_sample_layers): + output = layer.forward_part_1(output, stack[-1]) + output = self.grad_layers[-(i+1)](output, output[:,:1], up_sampled_input[-(i+2)][:,:1], up_sampled_input[-(i+2)][:,1:2]) + output = layer.forward_part_2(output, stack.pop()) + + return self.conv2(output) diff --git a/models/lightning/GriGAN.py b/models/lightning/GriGAN.py new file mode 100644 index 0000000..50e7c02 --- /dev/null +++ b/models/lightning/GriGAN.py @@ -0,0 +1,283 @@ +#Mass Mapping + +import torch + +import pytorch_lightning as pl +import numpy as np +import torch.autograd as autograd +from matplotlib import cm + +from PIL import Image +from torch.nn import functional as F +from 
models.archs.radio.gradient_generator import UNetModel + +from models.archs.radio.discriminator import DiscriminatorModel +from evaluation_scripts.metrics import psnr +from torchmetrics.functional import peak_signal_noise_ratio + +class GriGAN(pl.LightningModule): + def __init__(self, args, exp_name, num_gpus): + super().__init__() + self.args = args # This is the cfg object + self.exp_name = exp_name + self.num_gpus = num_gpus + + self.in_chans = args.in_chans + 2 # Two extra dimensions of the added noise + self.out_chans = args.out_chans + + try: + alt_upsample = self.args.alt_upsample + except: + alt_upsample = False + + self.generator = UNetModel( + in_chans=self.in_chans, + out_chans=self.out_chans, + chans = 64, # half the original as otherwise it is doubled because of the gradients + alt_upsample=alt_upsample + ) + + self.discriminator = DiscriminatorModel( + in_chans=self.args.in_chans + self.args.out_chans, # Number of channels from x and y + out_chans=self.out_chans, + input_im_size=self.args.im_size + ) + + self.std_mult = 1 + self.is_good_model = 0 + self.resolution = self.args.im_size + + self.save_hyperparameters() # Save passed values + + def get_noise(self, num_vectors): + z = torch.randn(num_vectors, 2, self.resolution, self.resolution, device=self.device) + return z + + def reformat(self, samples): + reformatted_tensor = torch.swapaxes(torch.clone(samples), 3, 1) + return reformatted_tensor + + def readd_measures(self, samples, measures): + return torch.clone(samples) + + def compute_gradient_penalty(self, real_samples, fake_samples, y): + """Calculates the gradient penalty loss for WGAN GP""" + Tensor = torch.FloatTensor + # Random weight term for interpolation between real and fake samples + alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1))).to(self.device) + # Get random interpolation between real and fake samples + interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) + d_interpolates = self.discriminator(input=interpolates, y=y) + fake = Tensor(real_samples.shape[0], 1).fill_(1.0).to( + self.device) + + # Get gradient w.r.t. 
interpolates + gradients = autograd.grad( + outputs=d_interpolates, + inputs=interpolates, + grad_outputs=fake, + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + gradients = gradients.view(gradients.size(0), -1) + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + return gradient_penalty + + def forward(self, y): + num_vectors = y.size(0) + noise = self.get_noise(num_vectors) + samples = self.generator(torch.cat([y, noise], dim=1)) + samples = self.readd_measures(samples, y) + return samples + + def adversarial_loss_discriminator(self, fake_pred, real_pred): + return fake_pred.mean() - real_pred.mean() + + def adversarial_loss_generator(self, y, gens): + fake_pred = torch.zeros(size=(y.shape[0], self.args.num_z_train), device=self.device) + for k in range(y.shape[0]): + cond = torch.zeros( + 1, + self.args.in_chans, + self.args.im_size, + self.args.im_size, + device=self.device + ) + cond[0, :, :, :] = y[k, :, :, :] + cond = cond.repeat(self.args.num_z_train, 1, 1, 1) + temp = self.discriminator(input=gens[k], y=cond) + fake_pred[k] = temp[:, 0] + + gen_pred_loss = torch.mean(fake_pred[0]) + for k in range(y.shape[0] - 1): + gen_pred_loss += torch.mean(fake_pred[k + 1]) + + adv_weight = 1e-5 + if self.current_epoch <= 4: + adv_weight = 1e-2 + elif self.current_epoch <= 22: + adv_weight = 1e-4 + + return - adv_weight * gen_pred_loss.mean() + + def l1_std_p(self, avg_recon, gens, x): + return F.l1_loss(avg_recon, x) - self.std_mult * np.sqrt( + 2 / (np.pi * self.args.num_z_train * (self.args.num_z_train+ 1)) + ) * torch.std(gens, dim=1).mean() + + def gradient_penalty(self, x_hat, x, y): + gradient_penalty = self.compute_gradient_penalty(x.data, x_hat.data, y.data) + + return self.args.gp_weight * gradient_penalty + + def drift_penalty(self, real_pred): + return 0.001 * torch.mean(real_pred ** 2) + + def training_step(self, batch, batch_idx, optimizer_idx): + y, x, mean, std = batch + + # train generator + if optimizer_idx == 1: + gens = torch.zeros( + size=( + y.size(0), + self.args.num_z_train, + self.args.out_chans, + self.args.im_size, + self.args.im_size + ), + device=self.device) + for z in range(self.args.num_z_train): + gens[:, z, :, :, :] = self.forward(y) + + avg_recon = torch.mean(gens, dim=1) + + # adversarial loss is binary cross-entropy + g_loss = self.adversarial_loss_generator(y, gens) + g_loss += self.l1_std_p(avg_recon, gens, x) + + self.log('g_loss', g_loss, prog_bar=True) + + return g_loss + + # train discriminator + if optimizer_idx == 0: + x_hat = self.forward(y) + + real_pred = self.discriminator(input=x, y=y) + fake_pred = self.discriminator(input=x_hat, y=y) + + d_loss = self.adversarial_loss_discriminator(fake_pred, real_pred) + d_loss += self.gradient_penalty(x_hat, x, y) + d_loss += self.drift_penalty(real_pred) + + self.log('d_loss', d_loss, prog_bar=True) + + return d_loss + + def validation_step(self, batch, batch_idx, external_test=False): + y, x, mean, std= batch + + fig_count = 0 + + if external_test: + num_code = self.args.num_z_test + else: + num_code = self.args.num_z_valid + + gens = torch.zeros(size=(y.size(0), num_code, self.args.out_chans, self.args.im_size, self.args.im_size), + device=self.device) + for z in range(num_code): + gens[:, z, :, :, :] = self.forward(y) * std[:, None, None, None] + mean[:, None, None, None] + + avg = torch.mean(gens, dim=1) + avg_gen = self.reformat(avg) + gt = self.reformat(x * std[:, None, None, None] + mean[:, None, None, None]) + + mag_avg_list = [] + mag_single_list = [] + 
mag_gt_list = [] + psnr_8s = [] + psnr_1s = [] + + for j in range(y.size(0)): + psnr_8s.append(peak_signal_noise_ratio(avg_gen[j], gt[j])) + psnr_1s.append(peak_signal_noise_ratio(self.reformat(gens[:, 0])[j], gt[j])) + + mag_avg_list.append(avg_gen[None, j, :, :, :]) + mag_single_list.append(self.reformat(gens[:, 0])[j]) + mag_gt_list.append(gt[None, j, :, :, :]) + + psnr_8s = torch.stack(psnr_8s) + psnr_1s = torch.stack(psnr_1s) + mag_avg_gen = torch.cat(mag_avg_list, dim=0) + mag_single_gen = torch.cat(mag_single_list, dim=0) + mag_gt = torch.cat(mag_gt_list, dim=0) + + self.log('psnr_8_step', psnr_8s.mean(), on_step=True, on_epoch=False, prog_bar=True) + self.log('psnr_1_step', psnr_1s.mean(), on_step=True, on_epoch=False, prog_bar=True) + + if batch_idx == 0: + if self.global_rank == 0 and self.current_epoch % 1 == 0 and fig_count == 0: + fig_count += 1 + # Using single generation instead of avg generator (mag_avg_gen) + avg_gen_np = mag_avg_gen[0, :, :, 0].cpu().numpy() + gt_np = mag_gt[0, :, :, 0].cpu().numpy() + + plot_avg_np = (avg_gen_np - np.min(avg_gen_np)) / (np.max(avg_gen_np) - np.min(avg_gen_np)) + plot_gt_np = (gt_np - np.min(gt_np)) / (np.max(gt_np) - np.min(gt_np)) + + np_psnr = psnr(gt_np, avg_gen_np) + + self.logger.log_image( + key=f"epoch_{self.current_epoch}_img", + images=[ + Image.fromarray(np.uint8(plot_avg_np*255), 'L'), + ], + caption=[f"Recon: PSNR (NP): {np_psnr:.2f}"] + ) + + self.trainer.strategy.barrier() + + return {'psnr_8': psnr_8s.mean(), 'psnr_1': psnr_1s.mean()} + + def validation_epoch_end(self, validation_step_outputs): + avg_psnr = self.all_gather(torch.stack([x['psnr_8'] for x in validation_step_outputs]).mean()).mean() + avg_single_psnr = self.all_gather(torch.stack([x['psnr_1'] for x in validation_step_outputs]).mean()).mean() + + avg_psnr = avg_psnr.cpu().numpy() + avg_single_psnr = avg_single_psnr.cpu().numpy() + + psnr_diff = (avg_single_psnr + 2.5) - avg_psnr + + mu_0 = 2e-2 + self.std_mult += mu_0 * psnr_diff + + if np.abs(psnr_diff) <= self.args.psnr_gain_tol: + self.is_good_model = 1 + else: + self.is_good_model = 0 + + self.trainer.strategy.barrier() + + def configure_optimizers(self): + opt_g = torch.optim.Adam( + self.generator.parameters(), + lr=self.args.lr, + betas=(self.args.beta_1, self.args.beta_2) + ) + opt_d = torch.optim.Adam( + self.discriminator.parameters(), + lr=self.args.lr, + betas=(self.args.beta_1, self.args.beta_2) + ) + return [opt_d, opt_g], [] + + def on_save_checkpoint(self, checkpoint): + checkpoint["beta_std"] = self.std_mult + checkpoint["is_valid"] = self.is_good_model + + def on_load_checkpoint(self, checkpoint): + self.std_mult = checkpoint["beta_std"] + self.is_good_model = checkpoint["is_valid"] diff --git a/models/lightning/riGAN.py b/models/lightning/riGAN.py new file mode 100644 index 0000000..de24ee8 --- /dev/null +++ b/models/lightning/riGAN.py @@ -0,0 +1,284 @@ +#Mass Mapping + +import torch + +import pytorch_lightning as pl +import numpy as np +import torch.autograd as autograd +from matplotlib import cm + +from PIL import Image +from torch.nn import functional as F +from utils.mri.fftc import ifft2c_new, fft2c_new #TODO: Unused imports. 
diff --git a/models/lightning/riGAN.py b/models/lightning/riGAN.py
new file mode 100644
index 0000000..de24ee8
--- /dev/null
+++ b/models/lightning/riGAN.py
@@ -0,0 +1,284 @@
+# Radio interferometric imaging
+
+import torch
+
+import pytorch_lightning as pl
+import numpy as np
+import torch.autograd as autograd
+from matplotlib import cm
+
+from PIL import Image
+from torch.nn import functional as F
+
+from models.archs.radio.generator import UNetModel
+from models.archs.radio.discriminator import DiscriminatorModel
+from evaluation_scripts.metrics import psnr
+from torchmetrics.functional import peak_signal_noise_ratio
+
+class riGAN(pl.LightningModule):
+    def __init__(self, args, exp_name, num_gpus):
+        super().__init__()
+        self.args = args  # This is the cfg object
+        self.exp_name = exp_name
+        self.num_gpus = num_gpus
+
+        self.in_chans = args.in_chans + 2  # Two extra channels for the added noise
+        self.out_chans = args.out_chans
+
+        # Default to ConvTranspose upsampling when the config omits alt_upsample
+        alt_upsample = getattr(self.args, 'alt_upsample', False)
+
+        self.generator = UNetModel(
+            in_chans=self.in_chans,
+            out_chans=self.out_chans,
+            alt_upsample=alt_upsample
+        )
+
+        self.discriminator = DiscriminatorModel(
+            in_chans=self.args.in_chans + self.args.out_chans,  # Number of channels from x and y
+            out_chans=self.out_chans,
+            input_im_size=self.args.im_size
+        )
+
+        self.std_mult = 1
+        self.is_good_model = 0
+        self.resolution = self.args.im_size
+
+        self.save_hyperparameters()  # Save passed values
+
+    def get_noise(self, num_vectors):
+        z = torch.randn(num_vectors, 2, self.resolution, self.resolution, device=self.device)
+        return z
+
+    def reformat(self, samples):
+        reformatted_tensor = torch.swapaxes(torch.clone(samples), 3, 1)
+        return reformatted_tensor
+
+    def readd_measures(self, samples, measures):
+        return torch.clone(samples)
+
+    def compute_gradient_penalty(self, real_samples, fake_samples, y):
+        """Calculates the gradient penalty loss for WGAN GP"""
+        Tensor = torch.FloatTensor
+        # Random weight term for interpolation between real and fake samples
+        alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1))).to(self.device)
+        # Get random interpolation between real and fake samples
+        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
+        d_interpolates = self.discriminator(input=interpolates, y=y)
+        fake = Tensor(real_samples.shape[0], 1).fill_(1.0).to(self.device)
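+        # NOTE: `fake` is an all-ones tensor passed as grad_outputs so that
+        # autograd.grad returns dD/d(interpolates) for every sample in the batch;
+        # the penalty below pushes that gradient's norm toward 1 (Gulrajani et al. 2017).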
+        # Get gradient w.r.t. interpolates
+        gradients = autograd.grad(
+            outputs=d_interpolates,
+            inputs=interpolates,
+            grad_outputs=fake,
+            create_graph=True,
+            retain_graph=True,
+            only_inputs=True,
+        )[0]
+        gradients = gradients.view(gradients.size(0), -1)
+        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+        return gradient_penalty
+
+    def forward(self, y):
+        num_vectors = y.size(0)
+        noise = self.get_noise(num_vectors)
+        samples = self.generator(torch.cat([y, noise], dim=1))
+        samples = self.readd_measures(samples, y)
+        return samples
+
+    def adversarial_loss_discriminator(self, fake_pred, real_pred):
+        return fake_pred.mean() - real_pred.mean()
+
+    def adversarial_loss_generator(self, y, gens):
+        fake_pred = torch.zeros(size=(y.shape[0], self.args.num_z_train), device=self.device)
+        for k in range(y.shape[0]):
+            cond = torch.zeros(
+                1,
+                self.args.in_chans,
+                self.args.im_size,
+                self.args.im_size,
+                device=self.device
+            )
+            cond[0, :, :, :] = y[k, :, :, :]
+            cond = cond.repeat(self.args.num_z_train, 1, 1, 1)
+            temp = self.discriminator(input=gens[k], y=cond)
+            fake_pred[k] = temp[:, 0]
+
+        gen_pred_loss = torch.mean(fake_pred[0])
+        for k in range(y.shape[0] - 1):
+            gen_pred_loss += torch.mean(fake_pred[k + 1])
+
+        adv_weight = 1e-5
+        if self.current_epoch <= 4:
+            adv_weight = 1e-2
+        elif self.current_epoch <= 22:
+            adv_weight = 1e-4
+
+        return -adv_weight * gen_pred_loss.mean()
+
+    def l1_std_p(self, avg_recon, gens, x):
+        return F.l1_loss(avg_recon, x) - self.std_mult * np.sqrt(
+            2 / (np.pi * self.args.num_z_train * (self.args.num_z_train + 1))
+        ) * torch.std(gens, dim=1).mean()
+
+    def gradient_penalty(self, x_hat, x, y):
+        gradient_penalty = self.compute_gradient_penalty(x.data, x_hat.data, y.data)
+
+        return self.args.gp_weight * gradient_penalty
+
+    def drift_penalty(self, real_pred):
+        return 0.001 * torch.mean(real_pred ** 2)
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        y, x, mean, std = batch
+
+        # train generator
+        if optimizer_idx == 1:
+            gens = torch.zeros(
+                size=(
+                    y.size(0),
+                    self.args.num_z_train,
+                    self.args.out_chans,
+                    self.args.im_size,
+                    self.args.im_size
+                ),
+                device=self.device)
+            for z in range(self.args.num_z_train):
+                gens[:, z, :, :, :] = self.forward(y)
+
+            avg_recon = torch.mean(gens, dim=1)
+
+            # Wasserstein adversarial loss plus the L1 + std-reward regularisation
+            g_loss = self.adversarial_loss_generator(y, gens)
+            g_loss += self.l1_std_p(avg_recon, gens, x)
+
+            self.log('g_loss', g_loss, prog_bar=True)
+
+            return g_loss
+
+        # train discriminator
+        if optimizer_idx == 0:
+            x_hat = self.forward(y)
+
+            real_pred = self.discriminator(input=x, y=y)
+            fake_pred = self.discriminator(input=x_hat, y=y)
+
+            d_loss = self.adversarial_loss_discriminator(fake_pred, real_pred)
+            d_loss += self.gradient_penalty(x_hat, x, y)
+            d_loss += self.drift_penalty(real_pred)
+
+            self.log('d_loss', d_loss, prog_bar=True)
+
+            return d_loss
+
+    def validation_step(self, batch, batch_idx, external_test=False):
+        y, x, mean, std = batch
+
+        fig_count = 0
+
+        if external_test:
+            num_code = self.args.num_z_test
+        else:
+            num_code = self.args.num_z_valid
+
+        gens = torch.zeros(size=(y.size(0), num_code, self.args.out_chans, self.args.im_size, self.args.im_size),
+                           device=self.device)
+        for z in range(num_code):
+            gens[:, z, :, :, :] = self.forward(y) * std[:, None, None, None] + mean[:, None, None, None]
+
+        avg = torch.mean(gens, dim=1)
+        avg_gen = self.reformat(avg)
+        gt = self.reformat(x * std[:, None, None, None] + mean[:, None, None, None])
+
+        mag_avg_list = []
+        mag_single_list = []
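+        # NOTE: reformat() returns channel-last tensors, and gens/gt were rescaled
+        # by the per-image std and mean above, so the PSNRs below are computed on
+        # de-standardised images.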
+        mag_gt_list = []
+        psnr_8s = []
+        psnr_1s = []
+
+        for j in range(y.size(0)):
+            psnr_8s.append(peak_signal_noise_ratio(avg_gen[j], gt[j]))
+            psnr_1s.append(peak_signal_noise_ratio(self.reformat(gens[:, 0])[j], gt[j]))
+
+            mag_avg_list.append(avg_gen[None, j, :, :, :])
+            mag_single_list.append(self.reformat(gens[:, 0])[j])
+            mag_gt_list.append(gt[None, j, :, :, :])
+
+        psnr_8s = torch.stack(psnr_8s)
+        psnr_1s = torch.stack(psnr_1s)
+        mag_avg_gen = torch.cat(mag_avg_list, dim=0)
+        mag_single_gen = torch.cat(mag_single_list, dim=0)
+        mag_gt = torch.cat(mag_gt_list, dim=0)
+
+        self.log('psnr_8_step', psnr_8s.mean(), on_step=True, on_epoch=False, prog_bar=True)
+        self.log('psnr_1_step', psnr_1s.mean(), on_step=True, on_epoch=False, prog_bar=True)
+
+        if batch_idx == 0:
+            if self.global_rank == 0 and self.current_epoch % 1 == 0 and fig_count == 0:
+                fig_count += 1
+                # Log the ground truth, averaged reconstruction, and error map for this epoch
+                avg_gen_np = mag_avg_gen[0, :, :, 0].cpu().numpy()
+                gt_np = mag_gt[0, :, :, 0].cpu().numpy()
+
+                plot_avg_np = (avg_gen_np - np.min(avg_gen_np)) / (np.max(avg_gen_np) - np.min(avg_gen_np))
+                plot_gt_np = (gt_np - np.min(gt_np)) / (np.max(gt_np) - np.min(gt_np))
+
+                np_psnr = psnr(gt_np, avg_gen_np)
+
+                self.logger.log_image(
+                    key=f"epoch_{self.current_epoch}_img",
+                    images=[
+                        Image.fromarray(np.uint8(plot_gt_np*255), 'L'),
+                        Image.fromarray(np.uint8(plot_avg_np*255), 'L'),
+                        Image.fromarray(np.uint8(cm.jet(5*np.abs(plot_gt_np - plot_avg_np))*255))
+                    ],
+                    caption=["GT", f"Recon: PSNR (NP): {np_psnr:.2f}", "Error"]
+                )
+
+        self.trainer.strategy.barrier()
+
+        return {'psnr_8': psnr_8s.mean(), 'psnr_1': psnr_1s.mean()}
+
+    def validation_epoch_end(self, validation_step_outputs):
+        avg_psnr = self.all_gather(torch.stack([x['psnr_8'] for x in validation_step_outputs]).mean()).mean()
+        avg_single_psnr = self.all_gather(torch.stack([x['psnr_1'] for x in validation_step_outputs]).mean()).mean()
+
+        avg_psnr = avg_psnr.cpu().numpy()
+        avg_single_psnr = avg_single_psnr.cpu().numpy()
+
+        psnr_diff = (avg_single_psnr + 2.5) - avg_psnr
+
+        mu_0 = 2e-2
+        self.std_mult += mu_0 * psnr_diff
+
+        if np.abs(psnr_diff) <= self.args.psnr_gain_tol:
+            self.is_good_model = 1
+        else:
+            self.is_good_model = 0
+
+        self.trainer.strategy.barrier()
+
+    def configure_optimizers(self):
+        opt_g = torch.optim.Adam(
+            self.generator.parameters(),
+            lr=self.args.lr,
+            betas=(self.args.beta_1, self.args.beta_2)
+        )
+        opt_d = torch.optim.Adam(
+            self.discriminator.parameters(),
+            lr=self.args.lr,
+            betas=(self.args.beta_1, self.args.beta_2)
+        )
+        return [opt_d, opt_g], []
+
+    def on_save_checkpoint(self, checkpoint):
+        checkpoint["beta_std"] = self.std_mult
+        checkpoint["is_valid"] = self.is_good_model
+
+    def on_load_checkpoint(self, checkpoint):
+        self.std_mult = checkpoint["beta_std"]
+        self.is_good_model = checkpoint["is_valid"]
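For orientation, this is how the `riGAN` module above is driven at inference time, as `scripts/radio/plot.py` below does (the gradient variant swaps in `GriGAN`). A minimal sketch assuming a trained checkpoint; the checkpoint path is a placeholder and the random `y` stands in for a real standardised (dirty image, PSF) batch:

```python
import torch

from models.lightning.riGAN import riGAN

# Placeholder path; real runs use cfg.checkpoint_dir + exp_name + '/checkpoint_best.ckpt'.
model = riGAN.load_from_checkpoint('checkpoint_best.ckpt').cuda().eval()

with torch.no_grad():
    # Stand-in for a standardised input of shape (batch, in_chans, im_size, im_size).
    y = torch.randn(1, 2, 360, 360).cuda()
    # Draw 32 approximate posterior samples by re-running the generator with fresh noise.
    samples = torch.stack([model(y) for _ in range(32)], dim=1)
    recon = samples.mean(dim=1)       # posterior-mean reconstruction
    uncertainty = samples.std(dim=1)  # per-pixel uncertainty map
```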
diff --git a/scripts/radio/plot.py b/scripts/radio/plot.py
index fc5216d..8be0e20 100644
--- a/scripts/radio/plot.py
+++ b/scripts/radio/plot.py
@@ -4,21 +4,24 @@
 import json
 
 import numpy as np
-import matplotlib.patches as patches
 
 import sys
-sys.path.append('/home/jjwhit/rcGAN/')
+sys.path.append('/home/mars/git/rcGAN/')
+
 print(sys.path)
 
-from data.lightning.MassMappingDataModule import MMDataModule
+from data.lightning.RadioDataModule import RadioDataModule
 from utils.parse_args import create_arg_parser
 from pytorch_lightning import seed_everything
-from models.lightning.mmGAN import mmGAN
+from models.lightning.riGAN import riGAN
+from models.lightning.GriGAN import GriGAN
+
 from utils.mri.math import tensor_to_complex_np
-import matplotlib.pyplot as plt
-from matplotlib import gridspec
-from scipy import ndimage
 import sys
+from datetime import date
+import pickle
+import os
+
 
 def load_object(dct):
     return types.SimpleNamespace(**dct)
@@ -35,246 +38,69 @@
         cfg = yaml.load(f, Loader=yaml.FullLoader)
         cfg = json.loads(json.dumps(cfg), object_hook=load_object)
 
-    dm = MMDataModule(cfg)
-    fig_count = 1
+    dm = RadioDataModule(cfg)
+    fig_count = 5
     dm.setup()
+    train_loader = dm.train_dataloader()
     test_loader = dm.test_dataloader()
 
+    today = date.today()
+    pred_dir = cfg.data_path + "/pred/"
+    os.makedirs(pred_dir, exist_ok=True)
+
+
     with torch.no_grad():
-        mmGAN_model = mmGAN.load_from_checkpoint(
-            checkpoint_path=cfg.checkpoint_dir + args.exp_name + '/checkpoint_best.ckpt')
-
-        mmGAN_model.cuda()
-
-        mmGAN_model.eval()
-
-        for i, data in enumerate(test_loader):
+        if cfg.__dict__.get("gradient", False):
+            RIGAN_model = GriGAN.load_from_checkpoint(
+                checkpoint_path=cfg.checkpoint_dir + args.exp_name + '/checkpoint_best.ckpt')
+        else:
+            RIGAN_model = riGAN.load_from_checkpoint(
+                checkpoint_path=cfg.checkpoint_dir + args.exp_name + '/checkpoint_best.ckpt')
+
+        RIGAN_model.cuda()
+
+        RIGAN_model.eval()
+
+        true = []
+        dirty = []
+        pred = []
+
+#        cfg.num_z_test = 100
+
+        for i, data in enumerate(train_loader):
+            print(f"{i}/{len(train_loader)}")
             y, x, mean, std = data
             y = y.cuda()
             x = x.cuda()
             mean = mean.cuda()
             std = std.cuda()
 
-            # gens_mmGAN = torch.zeros(
-            #     size=(y.size(0), cfg.num_z_test, cfg.in_chans // 2, cfg.im_size, cfg.im_size, 2)).cuda()
-            gens_mmGAN = torch.zeros(size=(y.size(0), cfg.num_z_test, cfg.im_size, cfg.im_size, 2)).cuda()
+            gens_RIGAN = torch.zeros(size=(y.size(0), cfg.num_z_test, cfg.im_size, cfg.im_size, 1)).cuda()
 
             for z in range(cfg.num_z_test):
-                gens_mmGAN[:, z, :, :, :] = mmGAN_model.reformat(mmGAN_model.forward(y))
+                gens_RIGAN[:, z, :, :, :] = RIGAN_model.reformat(RIGAN_model.forward(y))
 
-            avg_mmGAN = torch.mean(gens_mmGAN, dim=1)
+
+            avg_RIGAN = torch.mean(gens_RIGAN, dim=1)
 
-            gt = mmGAN_model.reformat(x)
-            zfr = mmGAN_model.reformat(y)
+            gt = RIGAN_model.reformat(x)
+            zfr = RIGAN_model.reformat(y)
+
+            # images here are real-valued, so the complex conversion is bypassed
+            tensor_to_complex_np = lambda x: x
 
             for j in range(y.size(0)):
-                np_avgs = {
-                    'mmGAN': None,
-                }
-
-                np_samps = {
-                    'mmGAN': [],
-                }
-
-                np_stds = {
-                    'mmGAN': None,
-                }
-
-                np_gt = None
-
-                # S = sp.linop.Multiply((cfg.im_size, cfg.im_size), tensor_to_complex_np(maps[j].cpu()))
-
-                np_gt = ndimage.rotate(
-                    torch.tensor(tensor_to_complex_np((gt[j] * std[j] + mean[j]).cpu())).abs().numpy(), 180)
-                np_zfr = ndimage.rotate(
-                    torch.tensor(tensor_to_complex_np((zfr[j] * std[j] + mean[j]).cpu())).abs().numpy(), 180)
-
-                np_avgs['mmGAN'] = ndimage.rotate(
-                    torch.tensor(tensor_to_complex_np((avg_mmGAN[j] * std[j] + mean[j]).cpu())).abs().numpy(),
-                    180)
-
-                for z in range(cfg.num_z_test):
-                    np_samps['mmGAN'].append(ndimage.rotate(torch.tensor(
-                        tensor_to_complex_np((gens_mmGAN[j, z] * std[j] + mean[j]).cpu())).abs().numpy(), 180))
-
-                np_stds['mmGAN'] = np.std(np.stack(np_samps['mmGAN']), axis=0)
-
-                method = 'mmGAN'
-                zoom_startx = np.random.randint(120, 250)
-                zoom_starty1 = np.random.randint(30, 80)
-                zoom_starty2 = np.random.randint(260, 300)
-
-                p = np.random.rand()
-                zoom_starty = zoom_starty1
-                if p <= 0.5:
-                    zoom_starty = zoom_starty2
-
-                zoom_length = 80
-
-                x_coord = zoom_startx + zoom_length
-                y_coords = [zoom_starty, zoom_starty + zoom_length]
-
-                # Global recon, error, std
-                nrow = 1
-                ncol = 4
-
-                fig
= plt.figure(figsize=(ncol + 1, nrow + 1)) - - gs = gridspec.GridSpec(nrow, ncol, - wspace=0.0, hspace=0.0, - top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1), - left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)) - - ax = plt.subplot(gs[0, 0]) - ax.imshow(np_gt, cmap='inferno', vmin=0, vmax=0.5* np.max(np_gt)) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title("Truth") - - ax = plt.subplot(gs[0, 1]) - ax.imshow(np_avgs[method], cmap='inferno', vmin=0, vmax=0.5** np.max(np_gt)) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title("Reconstruction") - - - ax = plt.subplot(gs[0, 2]) - im = ax.imshow(2 * np.abs(np_avgs[method] - np_gt), cmap='jet', vmin=0, - vmax=0.5*np.max(np.abs(np_avgs['mmGAN'] - np_gt))) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title("Error") - - - ax = plt.subplot(gs[0, 3]) - ax.imshow(np_stds[method], cmap='viridis', vmin=0, vmax=0.5*np.max(np_stds['mmGAN'])) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title("Std. Dev.") - - plt.savefig(f'/share/gpu0/jjwhit/test_figures_1/test_fig_avg_err_std_{fig_count}.png', bbox_inches='tight', dpi=300) - plt.close(fig) - - nrow = 1 - ncol = 8 - - fig = plt.figure(figsize=(ncol + 1, nrow + 1)) - - gs = gridspec.GridSpec(nrow, ncol, - wspace=0.0, hspace=0.0, - top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1), - left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)) - - ax = plt.subplot(gs[0, 0]) - ax.imshow(np_gt, cmap='inferno', vmin=0, vmax=0.5* np.max(np_gt)) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title('Truth') - - ax1 = ax - - rect = patches.Rectangle((zoom_startx, zoom_starty), zoom_length, zoom_length, linewidth=1, - edgecolor='r', - facecolor='none') - - # Add the patch to the Axes - ax.add_patch(rect) - - ax = plt.subplot(gs[0, 1]) - ax.imshow(np_gt[zoom_starty:zoom_starty + zoom_length, zoom_startx:zoom_startx + zoom_length], - cmap='inferno', - vmin=0, vmax=0.5 * np.max(np_gt)) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title('Truth') - - connection_path_1 = patches.ConnectionPatch([zoom_startx + zoom_length, zoom_starty], - [0, 0], coordsA=ax1.transData, - coordsB=ax.transData, color='r') - fig.add_artist(connection_path_1) - connection_path_2 = patches.ConnectionPatch([zoom_startx + zoom_length, zoom_starty + zoom_length], [0, zoom_length], - coordsA=ax1.transData, - coordsB=ax.transData, color='r') - fig.add_artist(connection_path_2) - - ax = plt.subplot(gs[0, 2]) - ax.imshow( - np_avgs[method][zoom_starty:zoom_starty + zoom_length, zoom_startx:zoom_startx + zoom_length], - cmap='inferno', vmin=0, vmax=0.5 * np.max(np_gt)) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title('32-Avg.') - - ax = plt.subplot(gs[0, 3]) - avg = np.zeros((384, 384)) - for l in range(4): - avg += np_samps[method][l] - - avg = avg / 8 - - ax.imshow( - avg[zoom_starty:zoom_starty + zoom_length, zoom_startx:zoom_startx + zoom_length], - cmap='inferno', vmin=0, vmax=0.5 * np.max(np_gt)) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title('8-Avg.') - - ax = plt.subplot(gs[0, 4]) - avg = np.zeros((384, 384)) - for l in range(2): - avg += np_samps[method][l] - - avg = avg / 4 - ax.imshow( - avg[zoom_starty:zoom_starty + 
zoom_length, zoom_startx:zoom_startx + zoom_length], - cmap='inferno', vmin=0, vmax=0.5 * np.max(np_gt)) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title('4-Avg.') - - for samp in range(2): - ax = plt.subplot(gs[0, samp + 5]) - ax.imshow(np_samps[method][samp][zoom_starty:zoom_starty + zoom_length, - zoom_startx:zoom_startx + zoom_length], cmap='inferno', vmin=0, - vmax=0.5** np.max(np_gt)) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title(f'Sample {samp + 1}') - - - ax = plt.subplot(gs[0, 7]) - ax.imshow(np_stds[method][zoom_starty:zoom_starty + zoom_length, - zoom_startx:zoom_startx + zoom_length], cmap='viridis', vmin=0, - vmax=0.5**np.max(np_stds['mmGAN'])) - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title('Std. Dev.') - - plt.savefig(f'/share/gpu0/jjwhit/test_figures_1/zoomed_avg_samps_{fig_count}.png', bbox_inches='tight', dpi=300) - plt.close(fig) - - if fig_count == args.num_figs: - sys.exit() - fig_count += 1 + + true.append( torch.tensor(tensor_to_complex_np((gt[j] * std[j] + mean[j]).cpu())).numpy().real) + dirty.append( torch.tensor(tensor_to_complex_np((zfr[j] * std[j] + mean[j]).cpu())).numpy().real ) + pred.append(torch.tensor(tensor_to_complex_np((gens_RIGAN[j] * std[j] + mean[j]).cpu())).numpy().real ) + + if len(true) == 10: + # save a small set after 10 predictions + pickle.dump([np.array(true), np.array(dirty), np.array(pred)], open(f"{pred_dir}/pred_train_{args.exp_name}_{today}_small.pkl", "wb")) + + if len(true) >= 1900: + break + pickle.dump([np.array(true), np.array(dirty), np.array(pred)], open(f"{pred_dir}/pred_train_{args.exp_name}_{today}.pkl", "wb")) + + + diff --git a/scripts/radio/plot_30dor.py b/scripts/radio/plot_30dor.py new file mode 100644 index 0000000..bcc38c7 --- /dev/null +++ b/scripts/radio/plot_30dor.py @@ -0,0 +1,107 @@ +import torch +import yaml +import types +import json + +import numpy as np +import matplotlib.patches as patches + +import sys +sys.path.append('/home/mars/git/rcGAN/') + +from data.lightning.RadioDataModule import RadioDataModule +from utils.parse_args import create_arg_parser +from pytorch_lightning import seed_everything +from models.lightning.riGAN import riGAN +from models.lightning.GriGAN import GriGAN + +from utils.mri.math import tensor_to_complex_np +import sys +from datetime import date +import pickle +import os + + +def load_object(dct): + return types.SimpleNamespace(**dct) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision('medium') + args = create_arg_parser().parse_args() + seed_everything(1, workers=True) + + config_path = args.config + + with open(config_path, 'r') as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + cfg = json.loads(json.dumps(cfg), object_hook=load_object) + + dm = RadioDataModule(cfg) + fig_count = 5 + dm.setup() + test_loader = dm.test_dataloader() + + today = date.today() + pred_dir = cfg.data_path + "/pred/" + os.makedirs(pred_dir, exist_ok=True) + + + with torch.no_grad(): + if cfg.__dict__.get("gradient", False): + RIGAN_model = GriGAN.load_from_checkpoint( + checkpoint_path=cfg.checkpoint_dir + args.exp_name + '/checkpoint_best.ckpt') + else: + RIGAN_model = riGAN.load_from_checkpoint( + checkpoint_path=cfg.checkpoint_dir + args.exp_name + '/checkpoint_best.ckpt') + + RIGAN_model.cuda() + + RIGAN_model.eval() + + true = [] + dirty = [] + pred = [] + +# cfg.num_z_test = 100 + + for i, data in enumerate([0]): 
+ print(f"{i}/{len([0])}") + im, d, p, _, _ = np.load('/home/mars/src_aiai/notebooks/GAN_30Dor.npy') + + pt_y, pt_x, mean, std = dm.test.transform((im, d, p)) + + y = torch.tensor(pt_y[None,:]).cuda() + x = torch.tensor(pt_x[None, :]).cuda() + mean = torch.tensor(np.array([mean])).cuda() + std = torch.tensor(np.array([std])).cuda() + + gens_RIGAN = torch.zeros(size=(y.size(0), cfg.num_z_test, cfg.im_size, cfg.im_size, 1)).cuda() + + for z in range(cfg.num_z_test): + gens_RIGAN[:, z, :, :, :] = RIGAN_model.reformat(RIGAN_model.forward(y)) + + + avg_RIGAN = torch.mean(gens_RIGAN, dim=1) + + gt = RIGAN_model.reformat(x) + zfr = RIGAN_model.reformat(y) + + + tensor_to_complex_np = lambda x: x + for j in range(y.size(0)): + + true.append( torch.tensor(tensor_to_complex_np((gt[j] * std[j] + mean[j]).cpu())).numpy().real) + dirty.append( torch.tensor(tensor_to_complex_np((zfr[j] * std[j] + mean[j]).cpu())).numpy().real ) + pred.append(torch.tensor(tensor_to_complex_np((gens_RIGAN[j] * std[j] + mean[j]).cpu())).numpy().real ) + + pickle.dump([np.array(true), np.array(dirty), np.array(pred)], open(f"{pred_dir}/pred_30DOR_{args.exp_name}_{today}.pkl", "wb")) + + exit() + + + + + + + diff --git a/scripts/radio/test.py b/scripts/radio/test.py deleted file mode 100644 index 84b5dd4..0000000 --- a/scripts/radio/test.py +++ /dev/null @@ -1,190 +0,0 @@ -import torch -import yaml -import types -import json -import lpips - -import numpy as np - -import sys -sys.path.append('/home/jjwhit/rcGAN/') -print(sys.path) - -from data.lightning.MassMappingDataModule import MMDataModule -from utils.parse_args import create_arg_parser -from pytorch_lightning import seed_everything -from models.lightning.mmGAN import mmGAN -from utils.mri.math import tensor_to_complex_np -from evaluation_scripts.metrics import psnr, ssim -from utils.embeddings import VGG16Embedding -from evaluation_scripts.mass_map_cfid.cfid_metric import CFIDMetric -from DISTS_pytorch import DISTS - - - -def load_object(dct): - return types.SimpleNamespace(**dct) - - -def rgb(im, im_size, unit_norm=False): - """ - Args: - im: Input image. - im_size (int): Width of (square) image. 
- """ - embed_ims = torch.zeros(size=(3, im_size, im_size)) - tens_im = torch.tensor(im) - - if unit_norm: - tens_im = (tens_im - torch.min(tens_im)) / (torch.max(tens_im) - torch.min(tens_im)) - else: - tens_im = 2 * (tens_im - torch.min(tens_im)) / (torch.max(tens_im) - torch.min(tens_im)) - 1 - - embed_ims[0, :, :] = tens_im - embed_ims[1, :, :] = tens_im - embed_ims[2, :, :] = tens_im - - return embed_ims.unsqueeze(0) - - -if __name__ == "__main__": - torch.set_float32_matmul_precision('medium') - args = create_arg_parser().parse_args() - seed_everything(1, workers=True) - - config_path = args.config - - with open(config_path, 'r') as f: - cfg = yaml.load(f, Loader=yaml.FullLoader) - cfg = json.loads(json.dumps(cfg), object_hook=load_object) - - cfg.batch_size = cfg.batch_size * 4 - dm = MMDataModule(cfg) - - - dm.setup() - - train_dataloader = dm.train_dataloader() - val_dataloader = dm.val_dataloader() - test_loader = dm.test_dataloader() - - lpips_met = lpips.LPIPS(net='alex') - dists_met = DISTS() - - with torch.no_grad(): - model = mmGAN.load_from_checkpoint( - checkpoint_path=cfg.checkpoint_dir + args.exp_name + '/checkpoint_best.ckpt') - model.cuda() - model.eval() - - n_samps = [1, 2, 4, 8, 16, 32] - - for n in n_samps: - print(f"\n\n{n} SAMPLES") - psnrs = [] - ssims = [] - apsds = [] - lpipss = [] - distss = [] - - for i, data in enumerate(test_loader): - y, x, mean, std = data - y = y.cuda() - x = x.cuda() - mean = mean.cuda() - std = std.cuda() - - # gens = torch.zeros(size=(y.size(0), n, cfg.in_chans // 2, cfg.im_size, cfg.im_size, 2)).cuda() - # for z in range(n): - # gens[:, z, :, :, :, :] = model.reformat(model.forward(y)) - gens = torch.zeros(size=(y.size(0), n, cfg.im_size, cfg.im_size, 2)).cuda() - for z in range(n): - gens[:, z, :, :, :] = model.reformat(model.forward(y)) - - avg = torch.mean(gens, dim=1) - - gt = model.reformat(x) - - for j in range(y.size(0)): - single_samps = np.zeros((n, cfg.im_size, cfg.im_size)) - - # S = sp.linop.Multiply((cfg.im_size, cfg.im_size), tensor_to_complex_np(maps[j].cpu())) - gt_ksp, avg_ksp = tensor_to_complex_np((gt[j] * std[j] + mean[j]).cpu()), tensor_to_complex_np( - (avg[j] * std[j] + mean[j]).cpu()) - - avg_gen_np = torch.tensor(avg_ksp).abs().numpy() - gt_np = torch.tensor(gt_ksp).abs().numpy() - - for z in range(n): - # np_samp = tensor_to_complex_np((gens[j, z, :, :, :, :] * std[j] + mean[j]).cpu()) - np_samp = tensor_to_complex_np((gens[j, z, :, :, :] * std[j] + mean[j]).cpu()) - single_samps[z, :, :] = torch.tensor(np_samp).abs().numpy() - - med_np = np.median(single_samps, axis=0) - - apsds.append(np.mean(np.std(single_samps, axis=0), axis=(0, 1))) - psnrs.append(psnr(gt_np, avg_gen_np)) - ssims.append(ssim(gt_np, avg_gen_np)) - lpipss.append(lpips_met(rgb(gt_np, cfg.im_size), rgb(avg_gen_np, cfg.im_size)).numpy()) - distss.append(dists_met(rgb(gt_np, cfg.im_size, unit_norm=True), rgb(avg_gen_np, cfg.im_size, unit_norm=True)).numpy()) - - print('AVG Recon') - print(f'PSNR: {np.mean(psnrs):.2f} \pm {np.std(psnrs) / np.sqrt(len(psnrs)):.2f}') - print(f'SSIM: {np.mean(ssims):.4f} \pm {np.std(ssims) / np.sqrt(len(ssims)):.4f}') - print(f'LPIPS: {np.mean(lpipss):.4f} \pm {np.std(lpipss) / np.sqrt(len(lpipss)):.4f}') - print(f'DISTS: {np.mean(distss):.4f} \pm {np.std(distss) / np.sqrt(len(distss)):.4f}') - print(f'APSD: {np.mean(apsds):.1f}') - - cfids = [] - m_comps = [] - c_comps = [] - - inception_embedding = VGG16Embedding(parallel=True) - # CFID_1 - cfid_metric = CFIDMetric(gan=model, - loader=test_loader, - 
image_embedding=inception_embedding, - condition_embedding=inception_embedding, - cuda=True, - args=cfg, - ref_loader=False, - num_samps=32) - - cfid, m_comp, c_comp = cfid_metric.get_cfid_torch_pinv() - cfids.append(cfid) - m_comps.append(m_comp) - c_comps.append(c_comp) - - # CFID_2 - cfid_metric = CFIDMetric(gan=model, - loader=val_dataloader, - image_embedding=inception_embedding, - condition_embedding=inception_embedding, - cuda=True, - args=cfg, - ref_loader=False, - num_samps=8) - - cfid, m_comp, c_comp = cfid_metric.get_cfid_torch_pinv() - cfids.append(cfid) - m_comps.append(m_comp) - c_comps.append(c_comp) - - # CFID_3 - cfid_metric = CFIDMetric(gan=model, - loader=val_dataloader, - image_embedding=inception_embedding, - condition_embedding=inception_embedding, - cuda=True, - args=cfg, - ref_loader=train_dataloader, - num_samps=1) - - cfid, m_comp, c_comp = cfid_metric.get_cfid_torch_pinv() - cfids.append(cfid) - m_comps.append(m_comp) - c_comps.append(c_comp) - - print("\n\n") - for l in range(3): - print(f'CFID_{l+1}: {cfids[l]:.2f}; M_COMP: {m_comps[l]:.4f}; C_COMP: {c_comps[l]:.4f}') diff --git a/scripts/radio/validate.py b/scripts/radio/validate.py index d244fca..04eb202 100644 --- a/scripts/radio/validate.py +++ b/scripts/radio/validate.py @@ -5,12 +5,18 @@ import json import numpy as np +import sys +sys.path.append('/home/mars/git/rcGAN/') +from data.lightning.RadioDataModule import RadioDataModule from data.lightning.MassMappingDataModule import MMDataModule from utils.parse_args import create_arg_parser -from models.lightning.mmGAN import mmGAN +from models.lightning.riGAN import riGAN +from models.lightning.GriGAN import GriGAN + from pytorch_lightning import seed_everything from utils.embeddings import VGG16Embedding -from evaluation_scripts.mass_map_cfid.cfid_metric import CFIDMetric +from evaluation_scripts.radio_cfid.cfid_metric import CFIDMetric + def load_object(dct): return types.SimpleNamespace(**dct) @@ -26,21 +32,31 @@ def load_object(dct): cfg = yaml.load(f, Loader=yaml.FullLoader) cfg = json.loads(json.dumps(cfg), object_hook=load_object) - dm = MMDataModule(cfg) + if cfg.experience == 'radio': + dm = RadioDataModule(cfg) + elif cfg.experience == 'mass_mapping': + dm = MMDataModule(cfg) + else: + exit("no data for specified experience") + dm.setup() +# train_loader = dm.train_dataloader() val_loader = dm.val_dataloader() best_epoch = -1 inception_embedding = VGG16Embedding() best_cfid = 10000000 start_epoch = 50 #Will start saving models after 50 epochs - + end_epoch = 100 with torch.no_grad(): for epoch in range(end_epoch): print(f"VALIDATING EPOCH: {epoch + 1}") try: - model = mmGAN.load_from_checkpoint(checkpoint_path=cfg.checkpoint_dir + args.exp_name + f'/checkpoint-epoch={epoch}.ckpt') + if cfg.__dict__.get("gradient", False): + model = GriGAN.load_from_checkpoint(checkpoint_path=cfg.checkpoint_dir + args.exp_name + f'/checkpoint-epoch={epoch}.ckpt') + else: + model = riGAN.load_from_checkpoint(checkpoint_path=cfg.checkpoint_dir + args.exp_name + f'/checkpoint-epoch={epoch}.ckpt') except Exception as e: print(e) continue @@ -73,12 +89,12 @@ def load_object(dct): print(f"BEST EPOCH: {best_epoch}") - for epoch in range(end_epoch): - try: - if epoch != best_epoch: - os.remove(cfg.checkpoint_dir + args.exp_name + f'/checkpoint-epoch={epoch}.ckpt') - except: - pass +# for epoch in range(end_epoch): +# try: +# if epoch != best_epoch: +# os.remove(cfg.checkpoint_dir + args.exp_name + f'/checkpoint-epoch={epoch}.ckpt') +# except: +# pass os.rename( 
        cfg.checkpoint_dir + args.exp_name + f'/checkpoint-epoch={best_epoch}.ckpt',
diff --git a/train.py b/train.py
index 546cea3..3800852 100644
--- a/train.py
+++ b/train.py
@@ -2,11 +2,14 @@
 import yaml
 import types
 import json
+import sys
+import os
+sys.path.append('/home/mars/git/rcGAN/')
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import ModelCheckpoint
-from data.lightning.MRIDataModule import MRIDataModule
+# from data.lightning.MRIDataModule import MRIDataModule
 from utils.parse_args import create_arg_parser
 from models.lightning.rcGAN import rcGAN
 from pytorch_lightning import seed_everything
@@ -14,6 +17,9 @@
 from data.lightning.MassMappingDataModule import MMDataModule
 from data.lightning.RadioDataModule import RadioDataModule
 from models.lightning.mmGAN import mmGAN
+from models.lightning.riGAN import riGAN
+from models.lightning.GriGAN import GriGAN
+
 
 def load_object(dct):
     return types.SimpleNamespace(**dct)
@@ -43,8 +49,12 @@
         dm = MMDataModule(cfg)
         model = mmGAN(cfg, args.exp_name, args.num_gpus)
     elif cfg.experience == 'radio':
+        cfg.num_workers = args.num_gpus  # set the number of data-loader workers equal to the number of GPUs
        dm = RadioDataModule(cfg)
-        model = mmGAN(cfg, args.exp_name, args.num_gpus)
+        if cfg.__dict__.get("gradient", False):
+            model = GriGAN(cfg, args.exp_name, args.num_gpus)
+        else:
+            model = riGAN(cfg, args.exp_name, args.num_gpus)
     else:
         print("No valid experience selected in config file. Options are 'mri', 'mass_mapping', 'radio'.")
         exit()
@@ -52,16 +62,18 @@
     wandb_logger = WandbLogger(
         project=cfg.experience,
         name=args.exp_name,
-        log_model="True",
+        log_model=True,
         save_dir=cfg.checkpoint_dir + 'wandb'
     )
+
+    os.makedirs(cfg.checkpoint_dir + args.exp_name + '/', exist_ok=True)
 
     checkpoint_callback_epoch = ModelCheckpoint(
         monitor='epoch',
         mode='max',
         dirpath=cfg.checkpoint_dir + args.exp_name + '/',
         filename='checkpoint-{epoch}',
-        # every_n_epochs=1,
+        every_n_epochs=1,
         save_top_k=20
     )
 
@@ -88,4 +100,4 @@
         trainer.fit(model, dm,
                     ckpt_path=cfg.checkpoint_dir + args.exp_name + f'/checkpoint-epoch={args.resume_epoch}.ckpt')
     else:
-        trainer.fit(model, dm)
+        trainer.fit(model, dm)
\ No newline at end of file
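The plotting scripts above serialise their outputs as a pickled `[true, dirty, pred]` triple of NumPy arrays. A short sketch of how such a file can be loaded downstream; the file name is illustrative, since real names carry the experiment name and date:

```python
import pickle

import numpy as np

# File written by scripts/radio/plot.py; the experiment name and date suffix will differ.
with open('pred/pred_train_my_exp_2024-01-01_small.pkl', 'rb') as f:
    true, dirty, pred = pickle.load(f)

# pred keeps one axis for the posterior samples: (N, num_z_test, im_size, im_size, 1).
recon = pred.mean(axis=1)        # posterior-mean reconstructions
uncertainty = pred.std(axis=1)   # per-pixel uncertainty maps
print(true.shape, dirty.shape, pred.shape)
```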