-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPreprocessing.py
165 lines (136 loc) · 5.98 KB
/
Preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import numpy as np
import torch
import h5py
from sklearn.model_selection import train_test_split
import argparse
import time
import psutil
from noise2noise.noise2noise import Noise2Noise3D
from noise2noise.datasets import load_dataset # or define your own if needed
##############################################
# Memory logging helper
##############################################
def log_memory_usage(stage=""):
    """Print a one-line memory report for this process and for the whole system.

    Args:
        stage (str): Label identifying the pipeline step being logged; it is
            echoed inside square brackets in the output line.
    """
    gib = 1024 ** 3
    mem_info = psutil.Process(os.getpid()).memory_info()
    # Take a single system snapshot so "used" and "total" describe the same
    # instant (two separate virtual_memory() calls could disagree).
    system = psutil.virtual_memory()
    rss_gb = mem_info.rss / gib  # Resident Set Size in GiB
    vms_gb = mem_info.vms / gib  # Virtual Memory Size in GiB
    used_mem = system.used / gib
    total_mem = system.total / gib
    print(f"[{time.strftime('%H:%M:%S')}] [{stage}] Process RSS: {rss_gb:.2f} GB, VMS: {vms_gb:.2f} GB | System Used: {used_mem:.2f} GB / {total_mem:.2f} GB")
##############################################
# Argument class
##############################################
class Args:
    """Plain attribute bag holding the training and preprocessing settings.

    All values are fixed defaults set in ``__init__``; callers may overwrite
    individual attributes after construction.
    """

    def __init__(self):
        cuda_ok = torch.cuda.is_available()

        # --- Optimisation / training loop -------------------------------
        self.learning_rate = 1e-3
        # Three Adam settings; presumably (beta1, beta2, eps) — TODO confirm
        # how Noise2Noise3D unpacks them.
        self.adam = [0.9, 0.99, 1e-8]
        self.nb_epochs = 50
        # Loss identifier interpreted by Noise2Noise3D; the set of other
        # accepted names is defined there — TODO confirm.
        self.loss = 'l2'
        self.cuda = cuda_ok

        # --- Batching ---------------------------------------------------
        # Effective batch size = per-GPU batch * GPU count, with the count
        # clamped to at least 1 so CPU-only runs still use a full batch.
        per_gpu = 20
        self.max_batch_size_per_gpu = per_gpu
        self.num_gpus = torch.cuda.device_count() if cuda_ok else 0
        self.batch_size = per_gpu * max(1, self.num_gpus)

        # --- Checkpointing ----------------------------------------------
        self.ckptPath = "ckptsHighRes"
        self.ckpt_overwrite = True  # presumably allows replacing existing checkpoints

        # --- Dataset / patching -----------------------------------------
        self.train_size = None
        self.valid_size = None
        self.crop_size = 64  # edge length of the cubic training patches
        self.stride = 32  # NOTE(review): confirm which code path consumes this
        self.no_crop = False
        self.add_noise = False
        self.seed = 42  # random_state for the train/validation split
##############################################
# Patch extraction function
##############################################
def create_patches_dual(volA, volB, patch_size=(64,64,64), stride=(64,64,64)):
    """Cut two aligned 3-D volumes into patches taken at identical positions.

    ``patchesA[i]`` and ``patchesB[i]`` always cover the same region. Patch
    origins start at 0 and advance by ``stride`` along each axis; a stride
    equal to the patch size gives non-overlapping tiles, a smaller stride
    gives overlapping patches. Voxels beyond the last full patch are dropped.

    Args:
        volA, volB: 3-D arrays of identical shape (D, H, W).
        patch_size: (pd, ph, pw) patch extent along each axis.
        stride: (sd, sh, sw) step between consecutive patch origins.

    Returns:
        tuple: ``(patchesA, patchesB)``, two float32 arrays of shape
        (N, pd, ph, pw), ordered with x varying fastest, then y, then z.
        N is 0 (shape stays 4-D) when the volume is smaller than one patch.

    Raises:
        ValueError: if the two volumes differ in shape, are not 3-D, or any
            patch size / stride component is not positive.
    """
    from itertools import product

    if volA.shape != volB.shape:
        raise ValueError(f"volume shapes differ: {volA.shape} vs {volB.shape}")
    d, h, w = volA.shape
    ps_d, ps_h, ps_w = patch_size
    st_d, st_h, st_w = stride
    if min(ps_d, ps_h, ps_w) <= 0 or min(st_d, st_h, st_w) <= 0:
        raise ValueError(
            f"patch_size and stride must be positive, got {patch_size} and {stride}"
        )

    zs = range(0, d - ps_d + 1, st_d)
    ys = range(0, h - ps_h + 1, st_h)
    xs = range(0, w - ps_w + 1, st_w)
    n_patches = len(zs) * len(ys) * len(xs)

    # Preallocate: each patch is copied (and cast to float32) exactly once,
    # and an empty result keeps the 4-D (0, pd, ph, pw) shape.
    patchesA = np.empty((n_patches, ps_d, ps_h, ps_w), dtype=np.float32)
    patchesB = np.empty_like(patchesA)
    for i, (z, y, x) in enumerate(product(zs, ys, xs)):
        patchesA[i] = volA[z:z + ps_d, y:y + ps_h, x:x + ps_w]
        patchesB[i] = volB[z:z + ps_d, y:y + ps_h, x:x + ps_w]
    return patchesA, patchesB
##############################################
# HDF5 volume loader
##############################################
def load_hdf5_volume(path, key="Volume"):
    """Read one HDF5 dataset fully into memory as a float32 NumPy array.

    Args:
        path: Location of the HDF5 file (anything ``h5py.File`` accepts).
        key (str): Name of the dataset inside the file; defaults to
            ``"Volume"``, the name this script's inputs use.

    Returns:
        np.ndarray: The dataset contents converted to float32.
    """
    with h5py.File(path, 'r') as f:
        vol = f[key][:]
    # copy=False: if the stored data is already float32, skip a second
    # full-volume copy (astype copies by default), halving peak memory.
    # `vol` is a fresh array owned only by this function, so no aliasing.
    return vol.astype(np.float32, copy=False)
##############################################
# Main function
##############################################
def main():
    """Load two noisy volumes, cut them into paired patches and train Noise2Noise3D.

    Pipeline: parse CLI -> load both HDF5 volumes -> extract aligned patches
    -> 90/10 train/validation split -> wrap in DataLoaders -> train.
    Memory usage is logged between stages.

    Raises:
        ValueError: if no patch of ``crop_size`` fits inside the volumes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--trainA', default="/lhome/ahmadfn/ICT2025/Paper/ConeScan/ReducedVolumes/reco64EvenDown.hdf5", type=str, help="Path to volume A (noisy version 1)")
    parser.add_argument('--trainB', default="/lhome/ahmadfn/ICT2025/Paper/ConeScan/ReducedVolumes/reco64OddDown.hdf5", type=str, help="Path to volume B (noisy version 2)")
    parser.add_argument('--ckptPath', type=str, default="ckpts", help="Where to save checkpoints")
    args_cli = parser.parse_args()

    params = Args()  # our argument class
    # NOTE(review): the CLI default "ckpts" is truthy, so this override always
    # runs unless an empty string is passed; Args' own ckptPath default is
    # effectively unused.
    if args_cli.ckptPath:
        params.ckptPath = args_cli.ckptPath

    # 1) Load volumes
    log_memory_usage("Before Loading Volumes")
    volA = load_hdf5_volume(args_cli.trainA)
    volB = load_hdf5_volume(args_cli.trainB)
    log_memory_usage("After Loading Volumes")

    # 2) Create patches
    log_memory_usage("Before Patch Creation")
    patch_shape = (params.crop_size,) * 3
    # Stride == patch size -> non-overlapping tiles (identical to the former
    # implicit 64-voxel stride while crop_size is 64, and without gaps if
    # crop_size is changed).
    # NOTE(review): params.stride (32) is not used here; passing it would give
    # overlapping patches and roughly 8x more data — confirm intent.
    patchesA, patchesB = create_patches_dual(
        volA, volB, patch_size=patch_shape, stride=patch_shape
    )
    # The full volumes are no longer needed; free them before the next
    # large allocations to keep peak memory down.
    del volA, volB
    log_memory_usage("After Patch Creation")

    # 3) Train/validation split
    n_patches = patchesA.shape[0]
    if n_patches == 0:
        raise ValueError(
            f"No patches of size {patch_shape} fit in the input volumes; "
            "check --trainA/--trainB or reduce crop_size."
        )
    train_idx, val_idx = train_test_split(
        np.arange(n_patches), test_size=0.1, random_state=params.seed
    )
    # Advanced (integer-array) indexing returns copies, so the full patch
    # arrays can be released once both splits exist.
    trainA, trainB = patchesA[train_idx], patchesB[train_idx]
    valA, valB = patchesA[val_idx], patchesB[val_idx]
    del patchesA, patchesB

    # 4) Convert to Tensors and wrap in DataLoaders
    log_memory_usage("Before Tensor Conversion")
    # torch.as_tensor shares memory with float32 NumPy arrays instead of
    # copying them (torch.tensor always copies); unsqueeze(1) inserts the
    # channel axis -> (N, 1, D, H, W).
    train_source_data = torch.as_tensor(trainA, dtype=torch.float32).unsqueeze(1)
    train_target_data = torch.as_tensor(trainB, dtype=torch.float32).unsqueeze(1)
    val_source_data = torch.as_tensor(valA, dtype=torch.float32).unsqueeze(1)
    val_target_data = torch.as_tensor(valB, dtype=torch.float32).unsqueeze(1)
    log_memory_usage("After Tensor Conversion")

    train_ds = torch.utils.data.TensorDataset(train_source_data, train_target_data)
    val_ds = torch.utils.data.TensorDataset(val_source_data, val_target_data)

    log_memory_usage("Before DataLoader Creation")
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=params.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=params.batch_size, shuffle=False)
    log_memory_usage("After DataLoader Creation")

    # 5) Initialize Noise2Noise3D & train
    n2n = Noise2Noise3D(params, trainable=True)
    log_memory_usage("Before Training Start")
    n2n.train(train_loader, val_loader)
    log_memory_usage("After Training End")


if __name__ == "__main__":
    main()