diff --git a/config_template.json b/config_template.json index a0f28e6..ef4aa05 100644 --- a/config_template.json +++ b/config_template.json @@ -98,5 +98,12 @@ "797937_0369619_2017-09-29_0e3a_BGRN_SR_clip_aligned.tif" ], "georeference_output_folder": "processed_data/results_georeference", - "tile_images_output_folder": "processed_data/results_tile_images" + "tile_images_output_folder": "processed_data/results_tile_images", + + "ndwi": { + "window_size": 64, + "blur_kernel": 9, + "blur_sigma": 6, + "majority_threshold": 0.55 + } } \ No newline at end of file diff --git a/visualize_ndwi.py b/visualize_ndwi.py new file mode 100644 index 0000000..3caf782 --- /dev/null +++ b/visualize_ndwi.py @@ -0,0 +1,356 @@ +""" +visualize_ndwi.py + +Visualizes NDWI label quality for the CoastlineExtraction pipeline. +Generates a 4-panel output: Original RGB | NDWI Heatmap | Binary Mask | Overlay + +Usage: + # Single image + python visualize_ndwi.py sample_data/PlanetLabs/image.tif + + # Entire directory + python visualize_ndwi.py sample_data/PlanetLabs/ + + # Custom output directory + python visualize_ndwi.py sample_data/PlanetLabs/ --output_dir results/ndwi_viz +""" + +import os +import sys +import argparse +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import rasterio +import cv2 +import warnings + +warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning) + +from load_config import load_config + + +# ── helpers ────────────────────────────────────────────────────────────────── + +# Default Gaussian blur parameters — matches ndwi_labels.py +KSIZE_BLUR = (9, 9) +SIGMA_X = 6 +SIGMA_Y = 6 + +# Majority threshold — matches ndwi_labels.py MAJORITY_THRESHOLD +MAJORITY_THRESHOLD = 0.55 + + +def load_bands(image_path): + """ + Load Green (band 2) and NIR (last band) from a PlanetLabs GeoTIFF. + Also loads RGB (bands 3,2,1) for display. + Band order: 1=Blue, 2=Green, 3=Red, 4=NIR (PlanetScope) + 1=Blue, 2=Green, 3=Red, 4=RedEdge, 5=NIR (RapidEye) + + Raises: + ValueError: If the image has fewer than 3 bands. + """ + with rasterio.open(image_path) as src: + if src.count < 3: + raise ValueError( + f"{image_path} has only {src.count} band(s). " + "Expected at least 3 (Blue, Green, Red) + NIR." + ) + nir_idx = src.count # NIR is always the last band + blue = src.read(1).astype(np.float32) + green = src.read(2).astype(np.float32) + red = src.read(3).astype(np.float32) + nir = src.read(nir_idx).astype(np.float32) + return blue, green, red, nir + + +def compute_ndwi(green, nir, blur_kernel=None, blur_sigma_x=None, blur_sigma_y=None): + """ + NDWI = (Green - NIR) / (Green + NIR) + Returns values in [-1, 1]; positive → water, negative → land. + + Gaussian blur parameters default to matching ndwi_labels.py: + kernel (9,9), sigmaX=6, sigmaY=6. + + Args: + green: Green band array (float32). + nir: NIR band array (float32). + blur_kernel: Tuple (kH, kW) for GaussianBlur kernel size. + blur_sigma_x: Sigma X for GaussianBlur. + blur_sigma_y: Sigma Y for GaussianBlur. + """ + if blur_kernel is None: + blur_kernel = KSIZE_BLUR + if blur_sigma_x is None: + blur_sigma_x = SIGMA_X + if blur_sigma_y is None: + blur_sigma_y = SIGMA_Y + + np.seterr(divide='ignore', invalid='ignore') + ndwi = (green - nir) / (green + nir) + ndwi = np.nan_to_num(ndwi, nan=0.0) + # Gaussian blur — matches ndwi_labels.py defaults + ndwi = cv2.GaussianBlur(ndwi, blur_kernel, blur_sigma_x, blur_sigma_y) + return ndwi + + +def sliding_window_otsu(ndwi, window_size=64): + """ + Simplified sliding-window Otsu thresholding (pixel-grid, majority voting) + with global mean-threshold fallback for uncovered areas. + + NOTE: This is a simplified approximation of the geometry-based sliding + window in ndwi_labels.py, which uses shapefile transect points as window + centres and rasterio.mask for extraction. Results will differ slightly. + This version is intended for quick visual inspection of NDWI label quality, + not for generating training labels. + + Args: + ndwi: 2-D NDWI array (float, range [-1, 1]). + window_size: Side length of each square window in pixels. + + Returns: + Binary mask (uint8): 1 = water, 0 = land. + """ + h, w = ndwi.shape + vote_map = np.zeros((h, w), dtype=np.int32) + count_map = np.zeros((h, w), dtype=np.int32) + + # Convert NDWI [-1,1] to uint8 matching ndwi_labels.py: (ndwi * 127) + 128 + # NDWI = 0 maps to 128; water (positive NDWI) > 128; land (negative) < 128 + ndwi_8bit = np.clip((ndwi * 127) + 128, 0, 255).astype(np.uint8) + + otsu_thresholds = [] # collect valid thresholds for global fallback + step = window_size // 2 + for y in range(0, h, step): + for x in range(0, w, step): + window = ndwi_8bit[y:y + window_size, x:x + window_size] + if window.size == 0: + continue + # Skip nearly-uniform windows — Otsu on single-class data + # (e.g. all-land) produces a meaningless threshold that splits + # land texture noise as "water", causing scattered white pixels. + # A std of 10 in uint8 space filters out windows that contain + # only one class (pure land ≈ std 3–8, pure water ≈ std 5–10). + if window.std() < 10: + continue + thresh, _ = cv2.threshold( + window, 0, 1, + cv2.THRESH_BINARY + cv2.THRESH_OTSU + ) + otsu_thresholds.append(thresh) + classified = (window >= thresh).astype(np.int32) + vote_map[y:y + window_size, x:x + window_size] += classified + count_map[y:y + window_size, x:x + window_size] += 1 + + # Majority vote for pixels covered by valid sliding windows + # Threshold matches ndwi_labels.py MAJORITY_THRESHOLD (0.55) + safe_count = np.where(count_map == 0, 1, count_map) + mask = (vote_map / safe_count >= MAJORITY_THRESHOLD).astype(np.uint8) + + # For uncovered pixels, fall back to a global mean threshold. + # This matches ndwi_labels.py's concatenation logic: sliding-window + # labels where windows exist, single mean-threshold classification + # elsewhere. The +10 offset biases toward land, matching ndwi_labels.py. + if otsu_thresholds: + mean_thresh = np.mean(otsu_thresholds) + 10 + global_mask = (ndwi_8bit >= mean_thresh).astype(np.uint8) + covered = count_map > 0 + mask = np.where(covered, mask, global_mask) + else: + # No valid windows — use a single global Otsu threshold + global_thresh, _ = cv2.threshold( + ndwi_8bit, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU + ) + mask = (ndwi_8bit >= global_thresh).astype(np.uint8) + + return mask + + +def normalize_for_display(arr): + """Stretch array to [0, 1] for display.""" + lo, hi = arr.min(), arr.max() + if hi == lo: + return np.zeros_like(arr, dtype=np.float32) + return (arr - lo) / (hi - lo) + + +def build_rgb(red, green, blue): + """Stack and percentile-stretch to a display-ready RGB image.""" + rgb = np.stack([red, green, blue], axis=-1) + lo, hi = np.percentile(rgb, 2), np.percentile(rgb, 98) + rgb = np.clip((rgb - lo) / (hi - lo + 1e-9), 0, 1) + return rgb + + +def print_ndwi_stats(ndwi, binary_mask, image_name): + """Print statistics to help tune thresholds.""" + water_pct = binary_mask.mean() * 100 + print(f"\n── NDWI Stats: {image_name} ──") + print(f" Min : {ndwi.min():.4f}") + print(f" Max : {ndwi.max():.4f}") + print(f" Mean : {ndwi.mean():.4f}") + print(f" Std : {ndwi.std():.4f}") + print(f" Water pixels : {water_pct:.1f}%") + print(f" Land pixels : {100 - water_pct:.1f}%") + + +# ── main visualisation ──────────────────────────────────────────────────────── + +def visualize_single(image_path, output_dir, window_size): + """ + Generate and save a 4-panel NDWI visualisation for one image. + Panels: Original RGB | NDWI Heatmap | Binary Mask | Overlay + """ + image_name = os.path.splitext(os.path.basename(image_path))[0] + os.makedirs(output_dir, exist_ok=True) + + # 1. Load bands + blue, green, red, nir = load_bands(image_path) + + # 2. Compute NDWI + binary mask + ndwi = compute_ndwi(green, nir) + binary_mask = sliding_window_otsu(ndwi, window_size=window_size) + + # 3. Stats + print_ndwi_stats(ndwi, binary_mask, image_name) + + # 4. Build display arrays + rgb = build_rgb(red, green, blue) + ndwi_display = normalize_for_display(ndwi) + + # Colour overlay: water=blue, land=tan + overlay = rgb.copy() + overlay[binary_mask == 1] = [0.20, 0.45, 0.85] # blue → water + overlay[binary_mask == 0] = overlay[binary_mask == 0] * 0.6 + np.array([0.6, 0.5, 0.3]) * 0.4 + + # 5. Plot + fig, axes = plt.subplots(1, 4, figsize=(20, 5)) + fig.suptitle(f"NDWI Label Quality — {image_name}", fontsize=13, fontweight='bold') + + axes[0].imshow(rgb) + axes[0].set_title("Original RGB") + axes[0].axis('off') + + im = axes[1].imshow(ndwi_display, cmap='RdYlBu', vmin=0, vmax=1) + axes[1].set_title("NDWI Heatmap\n(blue=water, red=land)") + axes[1].axis('off') + fig.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04) + + axes[2].imshow(binary_mask, cmap='gray', vmin=0, vmax=1) + axes[2].set_title("Binary Mask\n(white=water, black=land)") + axes[2].axis('off') + + axes[3].imshow(overlay) + axes[3].set_title("Overlay") + axes[3].axis('off') + water_patch = mpatches.Patch(color=[0.20, 0.45, 0.85], label='Water') + axes[3].legend(handles=[water_patch], loc='lower right', fontsize=8) + + plt.tight_layout() + + out_path = os.path.join(output_dir, f"{image_name}_ndwi_viz.png") + plt.savefig(out_path, dpi=150, bbox_inches='tight') + plt.close() + print(f" Saved → {out_path}") + return out_path + + +def visualize_directory(input_dir, output_dir, window_size): + """Process all .tif files in a directory.""" + tif_files = [ + os.path.join(input_dir, f) + for f in sorted(os.listdir(input_dir)) + if f.lower().endswith('.tif') and 'udm' not in f.lower() + ] + if not tif_files: + print(f"No .tif files found in {input_dir}") + return + + print(f"Found {len(tif_files)} image(s) in {input_dir}") + for path in tif_files: + print(f"\nProcessing: {os.path.basename(path)}") + try: + visualize_single(path, output_dir, window_size) + except Exception as e: + print(f" ERROR: {e}") + + +# ── CLI ─────────────────────────────────────────────────────────────────────── + +def parse_args(): + parser = argparse.ArgumentParser( + description="Visualize NDWI label quality for CoastlineExtraction pipeline." + ) + parser.add_argument( + "input", + help="Path to a single .tif image or a directory of .tif images." + ) + parser.add_argument( + "--output_dir", + default="ndwi_visualizations", + help="Directory to save PNG outputs (default: ndwi_visualizations/)." + ) + parser.add_argument( + "--window_size", + type=int, + default=None, + help="Sliding window size for Otsu thresholding (default: from config or 64)." + ) + parser.add_argument( + "--blur_kernel", + type=int, + default=None, + help="Gaussian blur kernel size, e.g. 9 for (9,9) (default: from config or 9)." + ) + parser.add_argument( + "--blur_sigma", + type=float, + default=None, + help="Gaussian blur sigma (default: from config or 6)." + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + # Load NDWI parameters from config, with CLI overrides + try: + config = load_config() + ndwi_cfg = config.get("ndwi", {}) + except Exception: + ndwi_cfg = {} + + window_size = args.window_size or ndwi_cfg.get("window_size", 64) + + blur_k = args.blur_kernel or ndwi_cfg.get("blur_kernel", 9) + blur_kernel = (blur_k, blur_k) + + blur_sigma = args.blur_sigma + if blur_sigma is None: + blur_sigma = ndwi_cfg.get("blur_sigma", 6) + blur_sigma_x = blur_sigma + blur_sigma_y = blur_sigma + + # Stash in module-level defaults so compute_ndwi picks them up + global KSIZE_BLUR, SIGMA_X, SIGMA_Y + KSIZE_BLUR = blur_kernel + SIGMA_X = int(blur_sigma_x) + SIGMA_Y = int(blur_sigma_y) + + print(f"Using window_size={window_size}, blur_kernel={blur_kernel}, " + f"blur_sigma=({blur_sigma_x}, {blur_sigma_y})") + + if os.path.isdir(args.input): + visualize_directory(args.input, args.output_dir, window_size) + elif os.path.isfile(args.input): + visualize_single(args.input, args.output_dir, window_size) + else: + print(f"ERROR: '{args.input}' is not a valid file or directory.") + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file