diff --git a/XPointMLTest.py b/XPointMLTest.py index c5e6d43..a2331db 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -14,6 +14,8 @@ import torch.optim as optim import torch.nn.functional as F +from torchvision.transforms import v2 # rotate tensor + from torch.utils.data import DataLoader, Dataset from timeit import default_timer as timer @@ -57,6 +59,41 @@ def expand_xpoints_mask(binary_mask, kernel_size=9): return expanded_mask +def rotate(frameData,deg): + if deg not in [90, 180, 270]: + print(f"invalid rotation specified... exiting") + sys.exit() + psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR) + mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR) + return { + "fnum": frameData["fnum"], + "rotation": deg, + "reflectionAxis": -1, # no reflection + "psi": psi, + "mask": mask, + "x": frameData["x"], + "y": frameData["y"], + "filenameBase": frameData["filenameBase"], + "params": frameData["params"] + } + +def reflect(frameData,axis): + if axis not in [0,1]: + print(f"invalid reflection axis specified... exiting") + sys.exit() + psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0) + mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0) + return { + "fnum": frameData["fnum"], + "rotation": 0, + "reflectionAxis": axis, + "psi": psi, + "mask": mask, + "x": frameData["x"], + "y": frameData["y"], + "filenameBase": frameData["filenameBase"], + "params": frameData["params"] + } # DATASET DEFINITION class XPointDataset(Dataset): @@ -69,7 +106,7 @@ class XPointDataset(Dataset): - Returns (psiTensor, maskTensor) as a PyTorch (float) pair. """ def __init__(self, paramFile, fnumList, constructJz=1, interpFac=1, - saveFig=1, xptCacheDir=None): + saveFig=1, xptCacheDir=None, rotateAndReflect=False): """ paramFile: Path to parameter file (string). fnumList: List of frames to iterate. @@ -100,20 +137,23 @@ def __init__(self, paramFile, fnumList, constructJz=1, interpFac=1, self.params["symBar"] = 1 self.params["colormap"] = 'bwr' - # output directory: - self.outDir = "plots" - os.makedirs(self.outDir, exist_ok=True) # load all the data self.data = [] for fnum in fnumList: - self.data.append(self.load(fnum)) + frameData = self.load(fnum) + self.data.append(frameData) + if rotateAndReflect: + self.data.append(rotate(frameData,90)) + self.data.append(rotate(frameData,180)) + self.data.append(rotate(frameData,270)) + self.data.append(reflect(frameData,0)) + self.data.append(reflect(frameData,1)) def __len__(self): - return len(self.fnumList) + return len(self.data) def __getitem__(self, idx): - fnum = self.fnumList[idx] return self.data[idx] def load(self, fnum): @@ -221,6 +261,8 @@ def load(self, fnum): return { "fnum": fnum, + "rotation": 0, + "reflectionAxis": -1, # no reflection "psi": psi_torch, # shape [1, Nx, Ny] "mask": mask_torch, # shape [1, Nx, Ny] // Used in: psi, mask = batch["psi"].to(device), batch["mask"].to(device) "x": x, @@ -408,7 +450,8 @@ def forward(self, inputs, targets): return 1.0 - dice # PLOTTING FUNCTION -def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, filenameBase, interpFac, +def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, + reflectionAxis, filenameBase, interpFac, xpoint_mask=None, titleExtra="", outDir="plots", @@ -436,7 +479,8 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, filenameBase, inte if params["axisEqual"]: plt.gca().set_aspect("equal", "box") - plt.title(f"Vector Potential Contours {titleExtra}, fileNum={fnum}") + plt.title(f"Vector Potential Contours {titleExtra}, fileNum={fnum}, " + f"reflectionAxis={reflectionAxis}") # Overlay X-points if xpoint_mask is given if xpoint_mask is not None: @@ -454,7 +498,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, filenameBase, inte basename = os.path.basename(filenameBase) saveFilename = os.path.join( outDir, - f"{basename}_interpFac_{interpFac}_{fnum:04d}{titleExtra.replace(' ','_')}.png" + f"{basename}_interpFac_{interpFac}_frame{fnum:04d}_rotation{rotation}_reflection{reflectionAxis}_{titleExtra.replace(' ','_')}.png" ) plt.savefig(saveFilename, dpi=300) print(" Figure written to", saveFilename) @@ -546,7 +590,7 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi plt.close() -def plot_training_history(train_losses, val_losses, save_path='output_images/training_history.png'): +def plot_training_history(train_losses, val_losses, save_path='plots/training_history.png'): """ Plots training and validation losses across epochs. @@ -602,6 +646,10 @@ def parseCommandLineArgs(): specify the path to a directory that will be used to cache the outputs of the analytic Xpoint finder ''') + parser.add_argument('--plot', type=bool, default=False, + help='create figures of the ground truth X-points and model identified X-points') + parser.add_argument('--plotDir', type=Path, default="./plots", + help='directory where figures are written') args = parser.parse_args() return args @@ -642,12 +690,17 @@ def main(): args = parseCommandLineArgs() checkCommandLineArgs(args) + # output directory: + outDir = args.plotDir + os.makedirs(outDir, exist_ok=True) + t0 = timer() train_fnums = range(args.trainFrameFirst, args.trainFrameLast) val_fnums = range(args.validationFrameFirst, args.validationFrameLast) train_dataset = XPointDataset(args.paramFile, train_fnums, constructJz=1, - interpFac=1, saveFig=1, xptCacheDir=args.xptCacheDir) + interpFac=1, saveFig=1, xptCacheDir=args.xptCacheDir, + rotateAndReflect=True) val_dataset = XPointDataset(args.paramFile, val_fnums, constructJz=1, interpFac=1, saveFig=1, xptCacheDir=args.xptCacheDir) @@ -664,6 +717,7 @@ def main(): criterion = DiceLoss(smooth=1.0) optimizer = optim.Adam(model.parameters(), lr=1e-5) + scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96) t2 = timer() print("time (s) to prepare model: " + str(t2-t1)) @@ -675,8 +729,12 @@ def main(): for epoch in range(num_epochs): train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device)) val_loss.append(validate_one_epoch(model, val_loader, criterion, device)) - print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}") + lr = scheduler.get_last_lr() + scheduler.step() # update lr with gamma + print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} " + f"ValLoss={val_loss[-1]} LearningRate {lr}") + plot_training_history(train_loss, val_loss) print("time (s) to train model: " + str(timer()-t2)) requiredLossDecreaseMagnitude = args.minTrainingLoss @@ -687,8 +745,7 @@ def main(): # (D) Plotting after training model.eval() - outDir = "output_images" - os.makedirs(outDir, exist_ok=True) + outDir = "plots" interpFac = 1 # Evaluate on combined set for demonstration. Exam this part to see if save to remove @@ -702,6 +759,8 @@ def main(): for item in set: # item is a dict with keys: fnum, psi, mask, psi_np, mask_np, x, y, tmp, params fnum = item["fnum"] + rotation = item["rotation"] + reflectionAxis = item["reflectionAxis"] psi_np = np.array(item["psi"])[0] mask_gt = np.array(item["mask"])[0] x = item["x"] @@ -721,7 +780,7 @@ def main(): pred_mask_bin = (pred_prob_np > 0.5).astype(np.float32) # Thresholding at 0.5, can be fine tune - print(f"Frame {fnum}:") + print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") print(f"psi shape: {psi_np.shape}, min: {psi_np.min()}, max: {psi_np.max()}") print(f"pred_bin shape: {pred_bin.shape}, min: {pred_bin.min()}, max: {pred_bin.max()}") print(f" Logits - min: {pred_mask_np.min():.5f}, max: {pred_mask_np.max():.5f}, mean: {pred_mask_np.mean():.5f}") @@ -729,30 +788,31 @@ def main(): print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") print(f" Binary Mask (X_points) - shape: {pred_mask_bin.shape}, min: {pred_mask_bin.min()}, max: {pred_mask_bin.max()}") - # Plot GROUND TRUTH - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, filenameBase, interpFac, - xpoint_mask=mask_gt, - titleExtra="(GT X-points)", - outDir=outDir, - saveFig=True - ) - - # Plot CNN PREDICTIONS - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, filenameBase, interpFac, - xpoint_mask=np.squeeze(pred_mask_bin), - titleExtra="(CNN X-points)", - outDir=outDir, - saveFig=True - ) - - pred_prob_np_full = pred_prob.cpu().numpy() - plot_model_performance( - psi_np, pred_prob_np_full, mask_gt, x, y, params, fnum, filenameBase, - outDir=outDir, - saveFig=True - ) + if args.plot : + # Plot GROUND TRUTH + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, + xpoint_mask=mask_gt, + titleExtra="GTXpoints", + outDir=outDir, + saveFig=True + ) + + # Plot CNN PREDICTIONS + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, + xpoint_mask=np.squeeze(pred_mask_bin), + titleExtra="CNNXpoints", + outDir=outDir, + saveFig=True + ) + + pred_prob_np_full = pred_prob.cpu().numpy() + plot_model_performance( + psi_np, pred_prob_np_full, mask_gt, x, y, params, fnum, filenameBase, + outDir=outDir, + saveFig=True + ) t5 = timer() print("time (s) to apply model: " + str(t5-t4))