exponential learning rate #7

Draft: wants to merge 10 commits into main
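This draft adds an exponential learning-rate schedule to the training loop: torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96) multiplies the Adam learning rate (initially 1e-5) by 0.96 after every epoch, so after t epochs lr = 1e-5 * 0.96**t. A minimal sketch of that behaviour, using a throwaway parameter that is not part of this change:

import torch

param = torch.nn.Parameter(torch.zeros(1))  # placeholder parameter, for illustration only
optimizer = torch.optim.Adam([param], lr=1e-5)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

for epoch in range(3):
    # train_one_epoch / validate_one_epoch would run here
    print(epoch, scheduler.get_last_lr())  # approximately [1e-05], [9.6e-06], [9.216e-06]
    scheduler.step()  # learning rate *= gamma

The training loop in the diff follows the same pattern: it logs scheduler.get_last_lr() before calling scheduler.step() at the end of each epoch.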
140 changes: 100 additions & 40 deletions XPointMLTest.py
@@ -14,6 +14,8 @@
import torch.optim as optim
import torch.nn.functional as F

from torchvision.transforms import v2 # rotate tensor

from torch.utils.data import DataLoader, Dataset

from timeit import default_timer as timer
@@ -57,6 +59,41 @@ def expand_xpoints_mask(binary_mask, kernel_size=9):

return expanded_mask

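# augmentation helper: return a copy of frameData with psi and mask rotated by deg (90, 180, or 270 degrees); frame number, coordinates, and params are passed through unchanged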
def rotate(frameData,deg):
if deg not in [90, 180, 270]:
print(f"invalid rotation {deg} specified... exiting")
sys.exit()
psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR)
mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR)
return {
"fnum": frameData["fnum"],
"rotation": deg,
"reflectionAxis": -1, # no reflection
"psi": psi,
"mask": mask,
"x": frameData["x"],
"y": frameData["y"],
"filenameBase": frameData["filenameBase"],
"params": frameData["params"]
}

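# augmentation helper: return a copy of frameData with psi and mask mirrored along array axis 0 or 1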
def reflect(frameData,axis):
if axis not in [0,1]:
print(f"invalid reflection axis {axis} specified... exiting")
sys.exit()
psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0)
mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0)
return {
"fnum": frameData["fnum"],
"rotation": 0,
"reflectionAxis": axis,
"psi": psi,
"mask": mask,
"x": frameData["x"],
"y": frameData["y"],
"filenameBase": frameData["filenameBase"],
"params": frameData["params"]
}

# DATASET DEFINITION
class XPointDataset(Dataset):
@@ -69,7 +106,7 @@ class XPointDataset(Dataset):
- Returns (psiTensor, maskTensor) as a PyTorch (float) pair.
"""
def __init__(self, paramFile, fnumList, constructJz=1, interpFac=1,
saveFig=1, xptCacheDir=None):
saveFig=1, xptCacheDir=None, rotateAndReflect=False):
"""
paramFile: Path to parameter file (string).
fnumList: List of frames to iterate.
@@ -100,20 +137,23 @@ def __init__(self, paramFile, fnumList, constructJz=1, interpFac=1,
self.params["symBar"] = 1
self.params["colormap"] = 'bwr'

# output directory:
self.outDir = "plots"
os.makedirs(self.outDir, exist_ok=True)

# load all the data
self.data = []
for fnum in fnumList:
self.data.append(self.load(fnum))
frameData = self.load(fnum)
self.data.append(frameData)
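# optional augmentation: each frame also contributes three rotated and two reflected copies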
if rotateAndReflect:
self.data.append(rotate(frameData,90))
self.data.append(rotate(frameData,180))
self.data.append(rotate(frameData,270))
self.data.append(reflect(frameData,0))
self.data.append(reflect(frameData,1))

def __len__(self):
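# length counts all cached samples, including rotated/reflected copies when rotateAndReflect=True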
return len(self.fnumList)
return len(self.data)

def __getitem__(self, idx):
fnum = self.fnumList[idx]
return self.data[idx]

def load(self, fnum):
@@ -221,6 +261,8 @@ def load(self, fnum):

return {
"fnum": fnum,
"rotation": 0,
"reflectionAxis": -1, # no reflection
"psi": psi_torch, # shape [1, Nx, Ny]
"mask": mask_torch, # shape [1, Nx, Ny] // Used in: psi, mask = batch["psi"].to(device), batch["mask"].to(device)
"x": x,
@@ -408,7 +450,8 @@ def forward(self, inputs, targets):
return 1.0 - dice

# PLOTTING FUNCTION
def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, filenameBase, interpFac,
def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation,
reflectionAxis, filenameBase, interpFac,
xpoint_mask=None,
titleExtra="",
outDir="plots",
@@ -436,7 +479,8 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, filenameBase, inte
if params["axisEqual"]:
plt.gca().set_aspect("equal", "box")

plt.title(f"Vector Potential Contours {titleExtra}, fileNum={fnum}")
plt.title(f"Vector Potential Contours {titleExtra}, fileNum={fnum}, "
f"reflectionAxis={reflectionAxis}")

# Overlay X-points if xpoint_mask is given
if xpoint_mask is not None:
@@ -454,7 +498,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, filenameBase, inte
basename = os.path.basename(filenameBase)
saveFilename = os.path.join(
outDir,
f"{basename}_interpFac_{interpFac}_{fnum:04d}{titleExtra.replace(' ','_')}.png"
f"{basename}_interpFac_{interpFac}_frame{fnum:04d}_rotation{rotation}_reflection{reflectionAxis}_{titleExtra.replace(' ','_')}.png"
)
plt.savefig(saveFilename, dpi=300)
print(" Figure written to", saveFilename)
@@ -546,7 +590,7 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi

plt.close()

def plot_training_history(train_losses, val_losses, save_path='output_images/training_history.png'):
def plot_training_history(train_losses, val_losses, save_path='plots/training_history.png'):
"""
Plots training and validation losses across epochs.

@@ -602,6 +646,10 @@ def parseCommandLineArgs():
specify the path to a directory that will be used to cache
the outputs of the analytic Xpoint finder
''')
# note: argparse's type=bool treats any non-empty string as True, so expose --plot as a true/false flag
parser.add_argument('--plot', action='store_true',
help='create figures of the ground truth X-points and model-identified X-points')
parser.add_argument('--plotDir', type=Path, default="./plots",
help='directory where figures are written')
args = parser.parse_args()
return args

@@ -642,12 +690,17 @@ def main():
args = parseCommandLineArgs()
checkCommandLineArgs(args)

# output directory:
outDir = args.plotDir
os.makedirs(outDir, exist_ok=True)

t0 = timer()
train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
val_fnums = range(args.validationFrameFirst, args.validationFrameLast)

train_dataset = XPointDataset(args.paramFile, train_fnums, constructJz=1,
interpFac=1, saveFig=1, xptCacheDir=args.xptCacheDir)
interpFac=1, saveFig=1, xptCacheDir=args.xptCacheDir,
rotateAndReflect=True)
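# validation frames are left unaugmented (rotateAndReflect defaults to False)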
val_dataset = XPointDataset(args.paramFile, val_fnums, constructJz=1,
interpFac=1, saveFig=1, xptCacheDir=args.xptCacheDir)

@@ -664,6 +717,7 @@

criterion = DiceLoss(smooth=1.0)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
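# exponential decay: each scheduler.step() (called once per epoch below) multiplies the learning rate by 0.96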
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

t2 = timer()
print("time (s) to prepare model: " + str(t2-t1))
@@ -675,8 +729,12 @@
for epoch in range(num_epochs):
train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device))
val_loss.append(validate_one_epoch(model, val_loader, criterion, device))
print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}")
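# get_last_lr() returns a list with one learning rate per parameter group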
lr = scheduler.get_last_lr()
scheduler.step() # update lr with gamma
print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} "
f"ValLoss={val_loss[-1]} LearningRate {lr}")

plot_training_history(train_loss, val_loss)
print("time (s) to train model: " + str(timer()-t2))

requiredLossDecreaseMagnitude = args.minTrainingLoss
@@ -687,8 +745,7 @@

# (D) Plotting after training
model.eval()
outDir = "output_images"
os.makedirs(outDir, exist_ok=True)
outDir = args.plotDir # keep writing figures to the directory created from --plotDir above
interpFac = 1

# Evaluate on the combined set for demonstration. Review this block to decide whether it is safe to remove.
@@ -702,6 +759,8 @@
for item in set:
# item is a dict with keys: fnum, rotation, reflectionAxis, psi, mask, x, y, filenameBase, params
fnum = item["fnum"]
rotation = item["rotation"]
reflectionAxis = item["reflectionAxis"]
psi_np = np.array(item["psi"])[0]
mask_gt = np.array(item["mask"])[0]
x = item["x"]
@@ -721,38 +780,39 @@

pred_mask_bin = (pred_prob_np > 0.5).astype(np.float32) # threshold at 0.5; this value can be fine-tuned

print(f"Frame {fnum}:")
print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:")
print(f"psi shape: {psi_np.shape}, min: {psi_np.min()}, max: {psi_np.max()}")
print(f"pred_bin shape: {pred_bin.shape}, min: {pred_bin.min()}, max: {pred_bin.max()}")
print(f" Logits - min: {pred_mask_np.min():.5f}, max: {pred_mask_np.max():.5f}, mean: {pred_mask_np.mean():.5f}")
print(f" Probabilities (after sigmoid) - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}")
print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels")
print(f" Binary Mask (X-points) - shape: {pred_mask_bin.shape}, min: {pred_mask_bin.min()}, max: {pred_mask_bin.max()}")

# Plot GROUND TRUTH
plot_psi_contours_and_xpoints(
psi_np, x, y, params, fnum, filenameBase, interpFac,
xpoint_mask=mask_gt,
titleExtra="(GT X-points)",
outDir=outDir,
saveFig=True
)

# Plot CNN PREDICTIONS
plot_psi_contours_and_xpoints(
psi_np, x, y, params, fnum, filenameBase, interpFac,
xpoint_mask=np.squeeze(pred_mask_bin),
titleExtra="(CNN X-points)",
outDir=outDir,
saveFig=True
)

pred_prob_np_full = pred_prob.cpu().numpy()
plot_model_performance(
psi_np, pred_prob_np_full, mask_gt, x, y, params, fnum, filenameBase,
outDir=outDir,
saveFig=True
)
if args.plot:
# Plot GROUND TRUTH
plot_psi_contours_and_xpoints(
psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac,
xpoint_mask=mask_gt,
titleExtra="GTXpoints",
outDir=outDir,
saveFig=True
)

# Plot CNN PREDICTIONS
plot_psi_contours_and_xpoints(
psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac,
xpoint_mask=np.squeeze(pred_mask_bin),
titleExtra="CNNXpoints",
outDir=outDir,
saveFig=True
)

pred_prob_np_full = pred_prob.cpu().numpy()
plot_model_performance(
psi_np, pred_prob_np_full, mask_gt, x, y, params, fnum, filenameBase,
outDir=outDir,
saveFig=True
)

t5 = timer()
print("time (s) to apply model: " + str(t5-t4))
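A sketch of how the new plotting options might be invoked; "..." stands for the script's other arguments (parameter file, frame ranges, cache directory), which are not part of this diff, and --plot is assumed to be a value-less flag:

python XPointMLTest.py ... --plot --plotDir ./figures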