Skip to content

Commit 4c5c8d2

Browse files
committed
update and add RI-GAN models and scripts
1 parent bb35883 commit 4c5c8d2

File tree

17 files changed

+2086
-433
lines changed

17 files changed

+2086
-433
lines changed

configs/radio_meerkat_macro.yaml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#Change checkpoint and sense_map path
2+
checkpoint_dir: /share/gpu0/mars/TNG_data/rcGAN/models/meerkat_macro/
3+
data_path: /share/gpu0/mars/TNG_data/rcGAN/meerkat_clean/
4+
5+
# Define the experience
6+
experience: radio
7+
8+
# Number of code vectors for each phase
9+
num_z_test: 32
10+
num_z_valid: 8
11+
num_z_train: 2
12+
13+
# Data
14+
in_chans: 2 # Real+Imag parts from obs
15+
out_chans: 1
16+
im_size: 360 #384x384 pixel images
17+
18+
# Options
19+
alt_upsample: False # False -> convt upsampling, True -> interpolate upsampling
20+
norm: macro # none, micro, macro
21+
22+
# Optimizer:
23+
lr: 0.001
24+
beta_1: 0
25+
beta_2: 0.99
26+
27+
# Loss weights
28+
gp_weight: 10
29+
adv_weight: 1e-5
30+
31+
# Training
32+
batch_size: 2 # per GPU
33+
accumulate_grad_batches: 2
34+
35+
#Remember to increase this for full training
36+
num_epochs: 100
37+
psnr_gain_tol: 0.25
38+
39+
num_workers: 4
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#Change checkpoint and sense_map path
2+
checkpoint_dir: /share/gpu0/mars/TNG_data/rcGAN/models/meerkat_macro/
3+
data_path: /share/gpu0/mars/TNG_data/rcGAN/meerkat_clean/
4+
5+
# Define the experience
6+
experience: radio
7+
8+
# Number of code vectors for each phase
9+
num_z_test: 32
10+
num_z_valid: 8
11+
num_z_train: 2
12+
13+
# Data
14+
in_chans: 2 # Real+Imag parts from obs
15+
out_chans: 1
16+
im_size: 360 #384x384 pixel images
17+
18+
# Options
19+
alt_upsample: False # False -> convt upsampling, True -> interpolate upsampling
20+
norm: macro # none, micro, macro
21+
gradient: True
22+
23+
# Optimizer:
24+
lr: 0.001
25+
beta_1: 0
26+
beta_2: 0.99
27+
28+
# Loss weights
29+
gp_weight: 10
30+
adv_weight: 1e-5
31+
32+
# Training
33+
batch_size: 2 # per GPU
34+
accumulate_grad_batches: 2
35+
36+
#Remember to increase this for full training
37+
num_epochs: 100
38+
psnr_gain_tol: 0.25
39+
40+
num_workers: 4

data/datasets/Radio_data.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,17 @@ def __init__(self, data_dir, transform, norm='micro'):
3030
# if micro we do the normalisation in the transform
3131
pass
3232
elif norm == 'macro':
33-
self.transform.mean_x, self.transform.std_x = self.x.mean(), self.x.std()
34-
self.transform.mean_y, self.transform.std_y = self.y.mean(), self.y.std()
35-
self.transform.mean_uv, self.transform.std_uv = self.uv.mean(), self.uv.std()
33+
# load means and stds from train set
34+
self.transform.mean_x = np.load(data_dir.parent.joinpath("train/mean_x.npy"))
35+
self.transform.std_x = np.load(data_dir.parent.joinpath("train/std_x.npy"))
36+
self.transform.mean_y = np.load(data_dir.parent.joinpath("train/mean_y.npy"))
37+
self.transform.std_y = np.load(data_dir.parent.joinpath("train/std_y.npy"))
38+
self.transform.mean_uv = np.load(data_dir.parent.joinpath("train/mean_uv.npy"))
39+
self.transform.std_uv = np.load(data_dir.parent.joinpath("train/std_uv.npy"))
40+
41+
# self.transform.mean_x, self.transform.std_x = self.x.mean(), self.x.std()
42+
# self.transform.mean_y, self.transform.std_y = self.y.mean(), self.y.std()
43+
# self.transform.mean_uv, self.transform.std_uv = self.uv.mean(), self.uv.std()
3644

3745
def __len__(self):
3846
"""Returns the number of samples in the dataset."""
@@ -72,9 +80,17 @@ def __init__(self, data_dir, transform, norm='micro'):
7280
# if micro we do the normalisation in the transform
7381
pass
7482
elif norm == 'macro':
75-
self.transform.mean_x, self.transform.std_x = self.x.mean(), self.x.std()
76-
self.transform.mean_y, self.transform.std_y = self.y.mean(), self.y.std()
77-
self.transform.mean_uv, self.transform.std_uv = self.uv.mean(), self.uv.std()
83+
# load means and stds from train set
84+
self.transform.mean_x = np.load(data_dir.parent.joinpath("train/mean_x.npy"))
85+
self.transform.std_x = np.load(data_dir.parent.joinpath("train/std_x.npy"))
86+
self.transform.mean_y = np.load(data_dir.parent.joinpath("train/mean_y.npy"))
87+
self.transform.std_y = np.load(data_dir.parent.joinpath("train/std_y.npy"))
88+
self.transform.mean_uv = np.load(data_dir.parent.joinpath("train/mean_uv.npy"))
89+
self.transform.std_uv = np.load(data_dir.parent.joinpath("train/std_uv.npy"))
90+
91+
# self.transform.mean_x, self.transform.std_x = self.x.mean(), self.x.std()
92+
# self.transform.mean_y, self.transform.std_y = self.y.mean(), self.y.std()
93+
# self.transform.mean_uv, self.transform.std_uv = self.uv.mean(), self.uv.std()
7894

7995
def __len__(self):
8096
"""Returns the number of samples in the dataset."""
@@ -116,6 +132,16 @@ def __init__(self, data_dir, transform, norm='micro'):
116132
self.transform.mean_x, self.transform.std_x = self.x.mean(), np.mean(self.x.std(axis=(1,2)))
117133
self.transform.mean_y, self.transform.std_y = self.y.mean(), np.mean(self.y.std(axis=(1,2)))
118134
self.transform.mean_uv, self.transform.std_uv = self.uv.mean(), np.mean(self.uv.std(axis=(1,2)))
135+
136+
np.save(data_dir.joinpath("mean_x.npy"), self.transform.mean_x)
137+
np.save(data_dir.joinpath("std_x.npy"), self.transform.std_x)
138+
np.save(data_dir.joinpath("mean_y.npy"), self.transform.mean_y)
139+
np.save(data_dir.joinpath("std_y.npy"), self.transform.std_y)
140+
np.save(data_dir.joinpath("mean_uv.npy"), self.transform.mean_uv)
141+
np.save(data_dir.joinpath("std_uv.npy"), self.transform.std_uv)
142+
143+
144+
119145

120146
def __len__(self):
121147
"""Returns the number of samples in the dataset."""

evaluation_scripts/radio_cfid/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)