Dev radio #4

Open: wants to merge 7 commits into base `master`
47 changes: 34 additions & 13 deletions README.md
@@ -1,9 +1,18 @@
# Generative modelling for mass-mapping with fast uncertainty quantification [[arXiv]](https://arxiv.org/abs/2410.24197)
# Generative modelling for fast image reconstruction and uncertainty quantification in astronomical imaging

MMGAN is a novel mass-mapping method based on the regularised conditional generative adversarial network (GAN) framework by [Bendel et al.](https://arxiv.org/abs/2210.13389). Designed to quickly generate approximate posterior samples of the convergence field from shear data, MMGAN offers a fully data-driven approach to mass-mapping. These posterior samples allow for the creation of detailed convergence map reconstructions with associated uncertainty maps, making MMGAN a cutting-edge tool for cosmological analysis.
This repository contains two novel image reconstruction methods based on the regularised conditional generative adversarial network (GAN) framework by [Bendel et al.](https://arxiv.org/abs/2210.13389). These methods are designed to quickly generate approximate posterior samples of the image from a set of noisy data, allowing for the creation of detailed image reconstructions with associated uncertainty maps. The two methods are:

**1. MMGAN**: *"Generative modelling for mass-mapping with fast uncertainty quantification"* [[arXiv]](https://arxiv.org/abs/2410.24197)

MMGAN is a novel mass-mapping method designed to quickly generate approximate posterior samples of the convergence field from shear data, offering a fully data-driven approach to mass-mapping. These posterior samples allow for the creation of detailed convergence map reconstructions with associated uncertainty maps, making MMGAN a cutting-edge tool for cosmological analysis.

![MMGAN COSMOS convergence map reconstruction](/figures/MMGAN/cosmos_results.png)


**2. RI-GAN**: *"Generative imaging for radio interferometry with fast uncertainty quantification"* [in prep.]

RI-GAN is a novel radio-interferometric imaging method that combines the regularised conditional GAN framework with model-based updates. This hybrid approach, grounded in the imaging model while remaining data-driven, quickly generates approximate posterior samples from the dirty image and the PSF of the observation. The result is a fast imaging method that is robust to varying visibility coverages, generalises well to unseen data, and provides informative uncertainty maps.
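As a rough sketch of the posterior-sampling idea described above (with a hypothetical `generator` standing in for the trained conditional GAN, and a toy stand-in used here for illustration), drawing several latent codes and summarising the resulting samples could look like:

```python
import numpy as np

def posterior_stats(generator, dirty_image, psf, n_samples=32, seed=0):
    """Draw posterior samples by varying the latent code z, then summarise
    them with a mean reconstruction and a per-pixel standard-deviation
    (uncertainty) map."""
    rng = np.random.default_rng(seed)
    samples = [generator(dirty_image, psf, rng.standard_normal(dirty_image.shape))
               for _ in range(n_samples)]
    samples = np.stack(samples)
    return samples.mean(axis=0), samples.std(axis=0)

# Toy stand-in generator: identity on the dirty image plus latent noise.
toy_generator = lambda y, psf, z: y + 0.1 * z
mean_map, std_map = posterior_stats(toy_generator, np.zeros((8, 8)), psf=None,
                                    n_samples=64)
```

The mean over samples plays the role of the reconstruction, and the per-pixel standard deviation is the uncertainty map.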

## Installation

After cloning the repository, if in a computing cluster, first run:
@@ -26,22 +35,34 @@ pip install -r pypi_requirements.txt
###
See ```docs/mass_mapping.md``` for detailed instructions on how to setup and reproduce the results from our paper on [MMGAN](https://arxiv.org/abs/2410.24197).

Alternatively, we have provided a [zenodo file]https://zenodo.org/records/14226221 with the weights of our trained model, as well as a number of simulations.
Alternatively, we have provided a [zenodo file](https://zenodo.org/records/14226221) with the weights of our trained model, as well as a number of simulations.

Documentation for the RI-GAN method is currently in preparation, but we will provide a similar guide for reproducing the results from our paper on RI-GAN once it is ready.

## Questions and Concerns
If you have any questions, or run into any issues, don't hesitate to reach out at [email protected]
If you have any questions, or run into any issues, don't hesitate to reach out at [email protected] for the MMGAN method and [email protected] for the RI-GAN method.

## References
This repository was forked from [rcGAN](https://github.com/matt-bendel/rcGAN) by [Bendel et al.](https://arxiv.org/abs/2210.13389), with significant changes and modifications made by Whitney et al.


## Citation
If you find this code helpful, please cite our paper:
```
@journal{2024arxiv,
author = {Whitney, Jessica and Liaudat, Tobías and Price, Matthew and Mars, Matthijs and McEwen, Jason},
title = {Generative modelling for mass-mapping with fast uncertainty quantification},
year = {2024},
journal={arXiv:2410.24197}
}
```
If you find this code helpful, please cite our papers:

- **MMGAN:**
```
@article{2024arxiv,
author = {Whitney, Jessica and Liaudat, Tobías and Price, Matthew and Mars, Matthijs and McEwen, Jason},
title = {Generative modelling for mass-mapping with fast uncertainty quantification},
year = {2024},
journal={arXiv:2410.24197}
}
```
- **RI-GAN:**
```
@article{marsGenerativeImagingRadioInterferometry,
author = {Mars, Matthijs and Liaudat, Tobías and Whitney, Jessica and McEwen, Jason},
title = {Generative imaging for radio interferometry with fast uncertainty quantification},
year = {},
journal={in prep.}
}
```
39 changes: 39 additions & 0 deletions configs/radio_meerkat_macro.yaml
@@ -0,0 +1,39 @@
# Change checkpoint_dir and data_path to your own locations
checkpoint_dir: /share/gpu0/mars/TNG_data/rcGAN/models/meerkat_macro/
data_path: /share/gpu0/mars/TNG_data/rcGAN/meerkat_clean/

# Define the experience
experience: radio

# Number of code vectors for each phase
num_z_test: 32
num_z_valid: 8
num_z_train: 2

# Data
in_chans: 2 # Real+Imag parts from obs
out_chans: 1
im_size: 360 # 360x360 pixel images

# Options
alt_upsample: False # False -> convt upsampling, True -> interpolate upsampling
norm: macro # none, micro, macro

# Optimizer:
lr: 0.001
beta_1: 0
beta_2: 0.99

# Loss weights
gp_weight: 10
adv_weight: 1e-5

# Training
batch_size: 2 # per GPU
accumulate_grad_batches: 2

# Remember to increase this for full training
num_epochs: 100
psnr_gain_tol: 0.25

num_workers: 4
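The data module reads optional keys from this config with a defaulting lookup (`args.__dict__.get('norm', 'micro')`), so older configs without `norm` or `gradient` still work. A minimal sketch of that fallback behaviour, using a hand-built namespace standing in for the parsed YAML:

```python
from types import SimpleNamespace

# Stand-in for the parsed config, e.g. yaml.safe_load on radio_meerkat_macro.yaml
args = SimpleNamespace(norm='macro', lr=0.001, batch_size=2, num_epochs=100)

# Optional keys fall back to a default, mirroring RadioDataTransform:
norm = args.__dict__.get('norm', 'micro')       # present -> 'macro'
gradient = args.__dict__.get('gradient', False)  # absent -> default False
```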
40 changes: 40 additions & 0 deletions configs/radio_meerkat_macro_gradient.yaml
@@ -0,0 +1,40 @@
# Change checkpoint_dir and data_path to your own locations
checkpoint_dir: /share/gpu0/mars/TNG_data/rcGAN/models/meerkat_macro/
data_path: /share/gpu0/mars/TNG_data/rcGAN/meerkat_clean/

# Define the experience
experience: radio

# Number of code vectors for each phase
num_z_test: 32
num_z_valid: 8
num_z_train: 2

# Data
in_chans: 2 # Real+Imag parts from obs
out_chans: 1
im_size: 360 # 360x360 pixel images

# Options
alt_upsample: False # False -> convt upsampling, True -> interpolate upsampling
norm: macro # none, micro, macro
gradient: True

# Optimizer:
lr: 0.001
beta_1: 0
beta_2: 0.99

# Loss weights
gp_weight: 10
adv_weight: 1e-5

# Training
batch_size: 2 # per GPU
accumulate_grad_batches: 2

# Remember to increase this for full training
num_epochs: 100
psnr_gain_tol: 0.25

num_workers: 4
81 changes: 67 additions & 14 deletions data/datasets/Radio_data.py
@@ -6,22 +6,37 @@

class RadioDataset_Test(torch.utils.data.Dataset):
"""Loads the test data."""
def __init__(self, data_dir, transform):
def __init__(self, data_dir, transform, norm='micro'):
"""
Args:
data_dir (path): The path to the dataset.
transform (callable): A callable object (class) that pre-processes the raw data into
appropriate form for it to be fed into the model.
norm (str): either 'none' (no normalisation), 'micro' (per sample normalisation), 'macro' (normalisation across all samples)
"""
self.transform = transform

# Load the dataset arrays:
# Test/x.npy, Test/y.npy, Test/uv.npy
self.x = np.load(data_dir.joinpath("x.npy")).astype(np.complex128)
self.y = np.load(data_dir.joinpath("y.npy")).astype(np.complex128)
self.x = np.load(data_dir.joinpath("x.npy")).astype(np.float64)
self.y = np.load(data_dir.joinpath("y.npy")).astype(np.float64)
self.uv = np.load(data_dir.joinpath("uv.npy")).real.astype(np.float64)
self.uv = (self.uv - self.uv.min())/(self.uv.max() - self.uv.min()) # normalize range of uv values to (0,1)


if norm == 'none':
self.transform.mean_x, self.transform.std_x = 0, 1
self.transform.mean_y, self.transform.std_y = 0, 1
self.transform.mean_uv, self.transform.std_uv = 0, 1
elif norm == 'micro':
# if micro we do the normalisation in the transform
pass
elif norm == 'macro':
# load means and stds from train set
self.transform.mean_x = np.load(data_dir.parent.joinpath("train/mean_x.npy"))
self.transform.std_x = np.load(data_dir.parent.joinpath("train/std_x.npy"))
self.transform.mean_y = np.load(data_dir.parent.joinpath("train/mean_y.npy"))
self.transform.std_y = np.load(data_dir.parent.joinpath("train/std_y.npy"))
self.transform.mean_uv = np.load(data_dir.parent.joinpath("train/mean_uv.npy"))
self.transform.std_uv = np.load(data_dir.parent.joinpath("train/std_uv.npy"))

def __len__(self):
"""Returns the number of samples in the dataset."""
@@ -37,21 +52,38 @@ def __getitem__(self,i):

class RadioDataset_Val(torch.utils.data.Dataset):
"""Loads the test data."""
def __init__(self, data_dir, transform):
def __init__(self, data_dir, transform, norm='micro'):
"""
Args:
data_dir (path): The path to the dataset.
transform (callable): A callable object (class) that pre-processes the raw data into
appropriate form for it to be fed into the model.
norm (str): either 'none' (no normalisation), 'micro' (per sample normalisation), 'macro' (normalisation across all samples)
"""
self.transform = transform

# Load the dataset arrays:
# Val/x.npy, Val/y.npy, Val/uv.npy
self.x = np.load(data_dir.joinpath("x.npy")).astype(np.complex128)
self.y = np.load(data_dir.joinpath("y.npy")).astype(np.complex128)
self.x = np.load(data_dir.joinpath("x.npy")).astype(np.float64)
self.y = np.load(data_dir.joinpath("y.npy")).astype(np.float64)
self.uv = np.load(data_dir.joinpath("uv.npy")).real.astype(np.float64)
self.uv = (self.uv - self.uv.min())/(self.uv.max() - self.uv.min()) # normalize range of uv values to (0,1)

if norm == 'none':
self.transform.mean_x, self.transform.std_x = 0, 1
self.transform.mean_y, self.transform.std_y = 0, 1
self.transform.mean_uv, self.transform.std_uv = 0, 1
elif norm == 'micro':
# if micro we do the normalisation in the transform
pass
elif norm == 'macro':
# load means and stds from train set
self.transform.mean_x = np.load(data_dir.parent.joinpath("train/mean_x.npy"))
self.transform.std_x = np.load(data_dir.parent.joinpath("train/std_x.npy"))
self.transform.mean_y = np.load(data_dir.parent.joinpath("train/mean_y.npy"))
self.transform.std_y = np.load(data_dir.parent.joinpath("train/std_y.npy"))
self.transform.mean_uv = np.load(data_dir.parent.joinpath("train/mean_uv.npy"))
self.transform.std_uv = np.load(data_dir.parent.joinpath("train/std_uv.npy"))


def __len__(self):
"""Returns the number of samples in the dataset."""
@@ -66,22 +98,43 @@ def __getitem__(self,i):

class RadioDataset_Train(torch.utils.data.Dataset):
"""Loads the test data."""
def __init__(self, data_dir, transform):
def __init__(self, data_dir, transform, norm='micro'):
"""
Args:
data_dir (path): The path to the dataset.
transform (callable): A callable object (class) that pre-processes the raw data into
appropriate form for it to be fed into the model.
norm (str): either 'none' (no normalisation), 'micro' (per sample normalisation), 'macro' (normalisation across all samples)
"""
self.transform = transform

# Load the dataset arrays:
# Train/x.npy, Train/y.npy, Train/uv.npy
self.x = np.load(data_dir.joinpath("x.npy")).astype(np.complex128)
self.y = np.load(data_dir.joinpath("y.npy")).astype(np.complex128)
self.x = np.load(data_dir.joinpath("x.npy")).astype(np.float64)
self.y = np.load(data_dir.joinpath("y.npy")).astype(np.float64)
self.uv = np.load(data_dir.joinpath("uv.npy")).real.astype(np.float64)
self.uv = (self.uv - self.uv.min())/(self.uv.max() - self.uv.min()) # normalize range of uv values to (0,1)


if norm == 'none':
self.transform.mean_x, self.transform.std_x = 0, 1
self.transform.mean_y, self.transform.std_y = 0, 1
self.transform.mean_uv, self.transform.std_uv = 0, 1
elif norm == 'micro':
# if micro we do the normalisation in the transform
pass
elif norm == 'macro':
self.transform.mean_x, self.transform.std_x = self.x.mean(), np.mean(self.x.std(axis=(1,2)))
self.transform.mean_y, self.transform.std_y = self.y.mean(), np.mean(self.y.std(axis=(1,2)))
self.transform.mean_uv, self.transform.std_uv = self.uv.mean(), np.mean(self.uv.std(axis=(1,2)))

np.save(data_dir.joinpath("mean_x.npy"), self.transform.mean_x)
np.save(data_dir.joinpath("std_x.npy"), self.transform.std_x)
np.save(data_dir.joinpath("mean_y.npy"), self.transform.mean_y)
np.save(data_dir.joinpath("std_y.npy"), self.transform.std_y)
np.save(data_dir.joinpath("mean_uv.npy"), self.transform.mean_uv)
np.save(data_dir.joinpath("std_uv.npy"), self.transform.std_uv)
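The 'macro' branch above computes one mean over the entire training set and averages the per-sample standard deviations, while 'micro' normalises each sample by its own statistics (what `transforms.normalize_instance` does per sample). A small NumPy sketch of the two conventions, on synthetic data:

```python
import numpy as np

# 16 synthetic 8x8 "images" with a shared offset and scale.
x = np.random.default_rng(0).normal(loc=3.0, scale=2.0, size=(16, 8, 8))

# macro: a single mean over the whole set; std is the average of the
# per-sample standard deviations (as in the 'macro' branch above).
macro_mean = x.mean()
macro_std = np.mean(x.std(axis=(1, 2)))
macro_norm = (x - macro_mean) / macro_std

# micro: each sample normalised by its own mean and std.
micro_norm = (x - x.mean(axis=(1, 2), keepdims=True)) \
             / x.std(axis=(1, 2), keepdims=True)
```

Macro normalisation preserves relative brightness between samples; micro normalisation forces every sample to zero mean and unit variance.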




def __len__(self):
"""Returns the number of samples in the dataset."""
31 changes: 22 additions & 9 deletions data/lightning/RadioDataModule.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,8 @@ def __init__(self, args, test=False, ISNR=30):
self.args = args
self.test = test
self.ISNR = ISNR

self.norm = args.__dict__.get('norm', 'micro')

def __call__(self, data) -> Tuple[float, float, float, float]:
""" Transforms the data.
@@ -36,21 +38,28 @@ def __call__(self, data) -> Tuple[float, float, float, float]:
x, y, uv = data



# Format input gt data.
pt_x = transforms.to_tensor(x) # Shape (H, W, 2)
pt_x = transforms.to_tensor(x)[:, :, None] # Shape (H, W, 1)
pt_x = pt_x.permute(2, 0, 1) # Shape (1, H, W)
# Format observation data.
pt_y = transforms.to_tensor(y) # Shape (H, W, 2)
pt_y = transforms.to_tensor(y)[:, :, None] # Shape (H, W, 1)
pt_y = pt_y.permute(2, 0, 1) # Shape (1, H, W)
# Format uv data
pt_uv = transforms.to_tensor(uv)[:, :, None] # Shape (H, W, 1)
pt_uv = pt_uv.permute(2, 0, 1) # Shape (1, H, W)
# Normalize everything based on measurements y
normalized_y, mean, std = transforms.normalize_instance(pt_y)
normalized_x = transforms.normalize(pt_x, mean, std)
normalized_uv = transforms.normalize(pt_uv, mean, std)


if self.norm != 'micro':
normalized_y = transforms.normalize(pt_y, self.mean_y, self.std_y) # scale globally
normalized_x = transforms.normalize(pt_x, self.mean_x, self.std_x) # scale globally
normalized_uv = transforms.normalize(pt_uv, self.mean_uv, self.std_uv) # scale globally
mean, std = self.mean_x, self.std_x
else:
normalized_y, mean, std = transforms.normalize_instance(pt_y)
normalized_x = transforms.normalize(pt_x, mean, std) # scale based on input
normalized_uv, _, _ = transforms.normalize_instance(pt_uv) # scale on itself

# Use normalized stack of y + uv
normalized_y = torch.cat([normalized_y, normalized_uv], dim=0)
@@ -72,6 +81,7 @@ def __init__(self, args):
super().__init__()
self.prepare_data_per_node = True
self.args = args
self.norm = args.__dict__.get('norm', 'micro')

def prepare_data(self):
pass
@@ -81,17 +91,20 @@ def setup(self, stage: Optional[str] = None):

train_data = RadioDataset_Train(
data_dir=pathlib.Path(self.args.data_path) / 'train',
transform=RadioDataTransform(self.args, test=False)
transform=RadioDataTransform(self.args, test=False),
norm=self.norm
)

dev_data = RadioDataset_Val(
data_dir=pathlib.Path(self.args.data_path) / 'val',
transform=RadioDataTransform(self.args, test=True)
transform=RadioDataTransform(self.args, test=True),
norm=self.norm
)

test_data = RadioDataset_Test(
data_dir=pathlib.Path(self.args.data_path) / 'test',
transform=RadioDataTransform(self.args, test=True)
transform=RadioDataTransform(self.args, test=True),
norm=self.norm
)

self.train, self.validate, self.test = train_data, dev_data, test_data