From d9bf7ccd06def3e139e64d00562fe894b7f1606a Mon Sep 17 00:00:00 2001 From: Pratik-Doshi-99 Date: Wed, 23 Jul 2025 00:11:30 +0000 Subject: [PATCH 1/3] built basic scaffold --- examples/clay_embeddings_example.py | 226 +++++++++++ samgeo/__init__.py | 1 + samgeo/clay.py | 555 ++++++++++++++++++++++++++++ samgeo/clay_metadata.yaml | 295 +++++++++++++++ 4 files changed, 1077 insertions(+) create mode 100644 examples/clay_embeddings_example.py create mode 100644 samgeo/clay.py create mode 100644 samgeo/clay_metadata.yaml diff --git a/examples/clay_embeddings_example.py b/examples/clay_embeddings_example.py new file mode 100644 index 00000000..920deac3 --- /dev/null +++ b/examples/clay_embeddings_example.py @@ -0,0 +1,226 @@ +""" +Example script demonstrating Clay foundation model embeddings with segment-geospatial. + +This script shows how to: +1. Load a geospatial image +2. Generate Clay foundation model embeddings +3. Save and load embeddings +4. Visualize embedding results + +Requirements: +- Clay model checkpoint file +- Geospatial imagery (GeoTIFF, etc.) +- Clay model dependencies: claymodel, torch, torchvision, pyyaml, python-box +""" + +import os +import numpy as np +import matplotlib.pyplot as plt +from samgeo import Clay, load_embeddings + + +def main(): + # Configuration + CHECKPOINT_PATH = "path/to/clay-model-checkpoint.ckpt" # Update this path + IMAGE_PATH = "path/to/your/satellite_image.tif" # Update this path + OUTPUT_DIR = "clay_embeddings_output" + + # Create output directory + os.makedirs(OUTPUT_DIR, exist_ok=True) + + print("=== Clay Foundation Model Embeddings Example ===\n") + + # Step 1: Initialize Clay embeddings model + print("1. Initializing Clay model...") + try: + clay = Clay( + checkpoint_path=CHECKPOINT_PATH, + device="auto", # Will use GPU if available + mask_ratio=0.0, # No masking for inference + shuffle=False + ) + print(" ✓ Clay model loaded successfully") + except Exception as e: + print(f" ✗ Error loading Clay model: {e}") + print(" Please ensure you have:") + print(" - Valid Clay checkpoint file") + print(" - Clay dependencies: pip install claymodel torch torchvision pyyaml python-box") + return + + # Step 2: Load and analyze image + print("\n2. Loading geospatial image...") + try: + clay.set_image( + source=IMAGE_PATH, + # sensor_type="sentinel-2-l2a", # Optional: override auto-detection + # date="2023-06-01", # Optional: specify acquisition date + # gsd_override=10.0 # Optional: override ground sample distance + ) + print(" ✓ Image loaded and analyzed") + print(f" - Image shape: {clay.image.shape}") + print(f" - Detected sensor: {clay.sensor_type}") + print(f" - Center coordinates: ({clay.lat:.4f}, {clay.lon:.4f})") + except Exception as e: + print(f" ✗ Error loading image: {e}") + print(" Please check the image path and format") + return + + # Step 3: Generate embeddings + print("\n3. Generating Clay embeddings...") + try: + # For large images, process in tiles + embeddings_result = clay.generate_embeddings( + tile_size=256, # Size of processing tiles + overlap=0.1 # 10% overlap between tiles + ) + + print(" ✓ Embeddings generated successfully") + print(f" - Number of tiles: {embeddings_result['num_tiles']}") + print(f" - Embedding shape: {embeddings_result['embeddings'].shape}") + print(f" - Feature dimension: {embeddings_result['embeddings'].shape[-1]}") + + except Exception as e: + print(f" ✗ Error generating embeddings: {e}") + return + + # Step 4: Save embeddings + print("\n4. 
Saving embeddings...") + try: + embeddings_file = os.path.join(OUTPUT_DIR, "clay_embeddings.npz") + clay.save_embeddings(embeddings_result, embeddings_file, format='npz') + print(f" ✓ Embeddings saved to {embeddings_file}") + except Exception as e: + print(f" ✗ Error saving embeddings: {e}") + return + + # Step 5: Load and verify embeddings + print("\n5. Loading and verifying saved embeddings...") + try: + loaded_embeddings = load_embeddings(embeddings_file) + print(" ✓ Embeddings loaded successfully") + print(f" - Sensor type: {loaded_embeddings['sensor_type']}") + print(f" - Number of tiles: {loaded_embeddings['num_tiles']}") + print(f" - Original image shape: {loaded_embeddings['image_shape']}") + except Exception as e: + print(f" ✗ Error loading embeddings: {e}") + return + + # Step 6: Visualize results + print("\n6. Creating visualizations...") + try: + # Plot RGB image if available + fig, axes = plt.subplots(1, 2, figsize=(15, 6)) + + # Original image (RGB bands if available) + image = clay.image + if clay.sensor_type in clay.metadata: + rgb_indices = clay.metadata[clay.sensor_type].get('rgb_indices', [0, 1, 2]) + if len(rgb_indices) == 3 and image.shape[2] >= max(rgb_indices) + 1: + rgb_image = image[:, :, rgb_indices] + # Normalize for display + rgb_image = np.clip(rgb_image / np.percentile(rgb_image, 98), 0, 1) + axes[0].imshow(rgb_image) + axes[0].set_title(f'Original Image ({clay.sensor_type})') + axes[0].axis('off') + else: + axes[0].imshow(image[:, :, 0], cmap='gray') + axes[0].set_title('Original Image (First Band)') + axes[0].axis('off') + else: + axes[0].imshow(image[:, :, 0], cmap='gray') + axes[0].set_title('Original Image (First Band)') + axes[0].axis('off') + + # Embedding visualization (PCA of first tile) + embeddings = embeddings_result['embeddings'] + if embeddings.shape[0] > 0: + # Use first embedding for visualization + first_embedding = embeddings[0].flatten() + + # Create a simple visualization of embedding values + embedding_2d = first_embedding[:256].reshape(16, 16) # Take first 256 values + axes[1].imshow(embedding_2d, cmap='viridis') + axes[1].set_title('Clay Embedding Visualization\n(First 256 features, first tile)') + axes[1].axis('off') + + plt.tight_layout() + + # Save plot + plot_file = os.path.join(OUTPUT_DIR, "clay_embeddings_visualization.png") + plt.savefig(plot_file, dpi=150, bbox_inches='tight') + plt.show() + + print(f" ✓ Visualization saved to {plot_file}") + + except Exception as e: + print(f" ✗ Error creating visualizations: {e}") + + # Step 7: Demonstrate embedding analysis + print("\n7. Embedding analysis...") + try: + embeddings = embeddings_result['embeddings'] + + # Basic statistics + print(f" - Embedding statistics:") + print(f" * Mean: {np.mean(embeddings):.4f}") + print(f" * Std: {np.std(embeddings):.4f}") + print(f" * Min: {np.min(embeddings):.4f}") + print(f" * Max: {np.max(embeddings):.4f}") + + # Similarity between tiles (if multiple tiles) + if embeddings.shape[0] > 1: + from sklearn.metrics.pairwise import cosine_similarity + similarities = cosine_similarity(embeddings) + avg_similarity = np.mean(similarities[np.triu_indices_from(similarities, k=1)]) + print(f" * Average tile similarity: {avg_similarity:.4f}") + + print(" ✓ Analysis complete") + + except Exception as e: + print(f" ✗ Error in embedding analysis: {e}") + + print(f"\n=== Example completed successfully! 
===") + print(f"Output files saved in: {OUTPUT_DIR}/") + print("\nNext steps:") + print("- Use embeddings for similarity search") + print("- Fine-tune on downstream tasks") + print("- Integrate with SAM for enhanced segmentation") + + +def example_with_numpy_array(): + """Example showing how to use Clay embeddings with numpy arrays.""" + print("\n=== Numpy Array Example ===") + + # Create a synthetic 4-band image (RGBI) + synthetic_image = np.random.randint(0, 255, (256, 256, 4), dtype=np.uint8) + + try: + # Initialize Clay model + clay = ClayEmbeddings( + checkpoint_path="path/to/clay-model-checkpoint.ckpt", + device="auto" + ) + + # Set synthetic image + clay.set_image( + source=synthetic_image, + sensor_type="naip", # Specify sensor type for numpy arrays + date="2023-06-01" + ) + + # Generate embeddings + result = clay.generate_embeddings(tile_size=256) + + print(f"Generated embeddings for synthetic image:") + print(f"- Shape: {result['embeddings'].shape}") + print(f"- Sensor: {result['sensor_type']}") + + except Exception as e: + print(f"Error in numpy array example: {e}") + + +if __name__ == "__main__": + main() + + # Uncomment to run numpy array example + # example_with_numpy_array() \ No newline at end of file diff --git a/samgeo/__init__.py b/samgeo/__init__.py index ffff87e1..9090796a 100644 --- a/samgeo/__init__.py +++ b/samgeo/__init__.py @@ -8,3 +8,4 @@ from .samgeo import * from .samgeo2 import * from .common import show_image_gui +from .clay import Clay, load_embeddings diff --git a/samgeo/clay.py b/samgeo/clay.py new file mode 100644 index 00000000..e396a5d5 --- /dev/null +++ b/samgeo/clay.py @@ -0,0 +1,555 @@ +""" +Clay foundation model wrapper for geospatial embeddings. + +This module provides a wrapper around the Clay foundation model for generating +rich spectral embeddings from geospatial imagery. It integrates with the +segment-geospatial library's raster I/O infrastructure. 
+""" + +import os +import math +import datetime +import numpy as np +import torch +import cv2 +import rasterio +import warnings +from typing import Optional, Union, Tuple, Dict, List, Any +from pathlib import Path + +try: + from claymodel.model import ClayMAEModule + from claymodel.utils import posemb_sincos_2d_with_gsd + from torchvision.transforms import v2 + import yaml + from box import Box + CLAY_AVAILABLE = True +except ImportError: + CLAY_AVAILABLE = False + +from .common import ( + check_file_path, + download_file, + transform_coords, + reproject, +) + + +# Default metadata for common sensors +DEFAULT_METADATA = { + 'sentinel-2-l2a': { + 'band_order': ['blue', 'green', 'red', 'rededge1', 'rededge2', 'rededge3', 'nir', 'nir08', 'swir16', 'swir22'], + 'rgb_indices': [2, 1, 0], + 'gsd': 10, + 'bands': { + 'mean': {'blue': 1105., 'green': 1355., 'red': 1552., 'rededge1': 1887., 'rededge2': 2422., 'rededge3': 2630., 'nir': 2743., 'nir08': 2785., 'swir16': 2388., 'swir22': 1835.}, + 'std': {'blue': 1809., 'green': 1757., 'red': 1888., 'rededge1': 1870., 'rededge2': 1732., 'rededge3': 1697., 'nir': 1742., 'nir08': 1648., 'swir16': 1470., 'swir22': 1379.}, + 'wavelength': {'blue': 0.493, 'green': 0.56, 'red': 0.665, 'rededge1': 0.704, 'rededge2': 0.74, 'rededge3': 0.783, 'nir': 0.842, 'nir08': 0.865, 'swir16': 1.61, 'swir22': 2.19} + } + }, + 'landsat-c2l2-sr': { + 'band_order': ['red', 'green', 'blue', 'nir08', 'swir16', 'swir22'], + 'rgb_indices': [0, 1, 2], + 'gsd': 30, + 'bands': { + 'mean': {'red': 13705., 'green': 13310., 'blue': 12474., 'nir08': 17801., 'swir16': 14615., 'swir22': 12701.}, + 'std': {'red': 9578., 'green': 9408., 'blue': 10144., 'nir08': 8277., 'swir16': 5300., 'swir22': 4522.}, + 'wavelength': {'red': 0.65, 'green': 0.56, 'blue': 0.48, 'nir08': 0.86, 'swir16': 1.6, 'swir22': 2.2} + } + }, + 'naip': { + 'band_order': ['red', 'green', 'blue', 'nir'], + 'rgb_indices': [0, 1, 2], + 'gsd': 1.0, + 'bands': { + 'mean': {'red': 110.16, 'green': 115.41, 'blue': 98.15, 'nir': 139.04}, + 'std': {'red': 47.23, 'green': 39.82, 'blue': 35.43, 'nir': 49.86}, + 'wavelength': {'red': 0.65, 'green': 0.56, 'blue': 0.48, 'nir': 0.842} + } + } +} + + +def normalize_timestamp(date): + """Normalize timestamp to week and hour components for Clay model.""" + if isinstance(date, str): + date = datetime.datetime.fromisoformat(date.replace('Z', '+00:00')) + elif not isinstance(date, datetime.datetime): + date = datetime.datetime.now() + + # Get day of year and hour + day_of_year = date.timetuple().tm_yday + hour = date.hour + + # Normalize to [-1, 1] range + week_norm = 2 * (day_of_year - 1) / 365 - 1 + hour_norm = 2 * hour / 24 - 1 + + return [week_norm, hour_norm] + + +def normalize_latlon(lat: float, lon: float) -> Tuple[List[float], List[float]]: + lat_rad = lat * np.pi / 180 + lon_rad = lon * np.pi / 180 + + lat_norm = [math.sin(lat_rad), math.cos(lat_rad)] + lon_norm = [math.sin(lon_rad), math.cos(lon_rad)] + + return lat_norm, lon_norm + + +class Clay: + """ + Clay foundation model wrapper for generating geospatial embeddings. + + This class provides an interface to generate rich spectral embeddings from + geospatial imagery using the Clay foundation model. + """ + + + + def __init__( + self, + checkpoint_path: str, + metadata_path: Optional[str] = None, + device: str = "auto", + ): + """ + Initialize Clay embeddings model. 
+ + Args: + checkpoint_path: Path to Clay model checkpoint + metadata_path: Path to Clay metadata YAML file (optional) + device: Device to run model on ('auto', 'cuda', 'cpu') + mask_ratio: Masking ratio for model (0.0 for inference) + shuffle: Whether to shuffle patches + """ + if not CLAY_AVAILABLE: + raise ImportError( + "Clay model dependencies not available. " + "Please install: pip install claymodel torch torchvision pyyaml python-box" + ) + + self.checkpoint_path = check_file_path(checkpoint_path, make_dirs=False) + if not os.path.exists(self.checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {self.checkpoint_path}") + + # Set device + if device == "auto": + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = torch.device(device) + + # Load metadata + if metadata_path and os.path.exists(metadata_path): + with open(metadata_path, 'r') as f: + self.metadata = Box(yaml.safe_load(f)) + else: + self.metadata = Box(self.DEFAULT_METADATA) + if metadata_path: + warnings.warn(f"Metadata file not found: {metadata_path}. Using defaults.") + + # Load model + self._load_model() + + # Image processing attributes + self.image = None + self.source = None + self.sensor_type = None + self.raster_profile = None + + def _load_model(self): + """Load the Clay model from checkpoint.""" + try: + torch.set_default_device(self.device) + self.model = ClayMAEModule.load_from_checkpoint( + self.checkpoint_path, + shuffle=False, + mask_ratio=0.0 + ) + self.model.eval() + self.model = self.model.to(self.device) + except Exception as e: + raise RuntimeError(f"Failed to load Clay model: {e}") + + def _detect_sensor_type( + self, + src: rasterio.DatasetReader, + source_path: Optional[str] = None + ) -> str: + """ + Detect sensor type from raster metadata and characteristics. + + Args: + src: Rasterio dataset reader + source_path: Optional source file path for filename-based detection + + Returns: + Detected sensor type string + """ + band_count = src.count + resolution = abs(src.transform[0]) # Pixel size + + # Try filename-based detection first + if source_path: + filename = os.path.basename(source_path).lower() + if 'sentinel' in filename or 's2' in filename: + return 'sentinel-2-l2a' + elif 'landsat' in filename or 'l8' in filename or 'l9' in filename: + return 'landsat-c2l2-sr' + elif 'naip' in filename: + return 'naip' + + # Fallback to resolution and band count heuristics + if band_count == 4 and resolution <= 5: + return 'naip' # High-res 4-band imagery + elif band_count >= 6 and 25 <= resolution <= 35: + return 'landsat-c2l2-sr' # Landsat resolution + elif band_count >= 10 and 8 <= resolution <= 12: + return 'sentinel-2-l2a' # Sentinel-2 resolution + elif band_count == 4: + return 'naip' # Default 4-band to NAIP + else: + # Default fallback + warnings.warn( + f"Could not detect sensor type (bands: {band_count}, " + f"resolution: {resolution:.1f}m). Defaulting to NAIP." 
+ ) + return 'naip' + + def _get_raster_center_latlon(self, src: rasterio.DatasetReader) -> Tuple[float, float]: + """Get the center lat/lon of the raster.""" + bounds = src.bounds + center_x = (bounds.left + bounds.right) / 2 + center_y = (bounds.bottom + bounds.top) / 2 + + # Transform to WGS84 if needed + if src.crs != 'EPSG:4326': + lon, lat = transform_coords( + [(center_x, center_y)], + src.crs, + 'EPSG:4326' + )[0] + else: + lon, lat = center_x, center_y + + return lat, lon + + def _prepare_datacube( + self, + image: np.ndarray, + sensor_type: str, + lat: float, + lon: float, + date: Optional[datetime.datetime] = None, + gsd_override: Optional[float] = None + ) -> Dict[str, torch.Tensor]: + """ + Prepare datacube for Clay model input. + + Args: + image: Input image array [H, W, C] + sensor_type: Detected sensor type + lat: Latitude of image center + lon: Longitude of image center + date: Image acquisition date + gsd_override: Override GSD value + + Returns: + Datacube dictionary for Clay model + """ + if date is None: + date = datetime.datetime.now() + + # Get sensor metadata + sensor_meta = self.metadata[sensor_type] + band_order = sensor_meta.band_order + gsd = gsd_override if gsd_override is not None else sensor_meta.gsd + + # Extract normalization parameters + means = [sensor_meta.bands.mean[band] for band in band_order] + stds = [sensor_meta.bands.std[band] for band in band_order] + wavelengths = [sensor_meta.bands.wavelength[band] for band in band_order] + + # Convert image to torch tensor and normalize + # Ensure we have the right number of bands + if image.shape[2] != len(band_order): + warnings.warn( + f"Image has {image.shape[2]} bands but sensor {sensor_type} " + f"expects {len(band_order)} bands. Using available bands." + ) + # Take only the available bands + num_bands = min(image.shape[2], len(band_order)) + image = image[:, :, :num_bands] + means = means[:num_bands] + stds = stds[:num_bands] + wavelengths = wavelengths[:num_bands] + + # Convert to tensor and transpose to [C, H, W] + pixels = torch.from_numpy(image.astype(np.float32)).permute(2, 0, 1) + + # Normalize + transform = v2.Compose([v2.Normalize(mean=means, std=stds)]) + pixels = transform(pixels).unsqueeze(0) # Add batch dimension + + # Prepare temporal encoding + time_norm = normalize_timestamp(date) + + # Prepare spatial encoding + lat_norm, lon_norm = normalize_latlon(lat, lon) + + # Create datacube + datacube = { + 'pixels': pixels.to(self.device), + 'time': torch.tensor( + time_norm + time_norm, # Clay expects 4 elements: [week, hour, week, hour] + dtype=torch.float32, + device=self.device + ).unsqueeze(0), + 'latlon': torch.tensor( + lat_norm + lon_norm, # Clay expects 4 elements: [sin_lat, cos_lat, sin_lon, cos_lon] + dtype=torch.float32, + device=self.device + ).unsqueeze(0), + 'gsd': torch.tensor(gsd, device=self.device), + 'waves': torch.tensor(wavelengths, device=self.device) + } + + return datacube + + def set_image( + self, + source: Union[str, np.ndarray], + sensor_type: Optional[str] = None, + date: Optional[Union[str, datetime.datetime]] = None, + gsd_override: Optional[float] = None + ): + """ + Set the input image for embedding generation. 
+ + Args: + source: Path to image file or numpy array + sensor_type: Optional sensor type override + date: Image acquisition date + gsd_override: Override GSD value + """ + if isinstance(source, str): + if source.startswith("http"): + source = download_file(source) + + if not os.path.exists(source): + raise ValueError(f"Input path {source} does not exist.") + + # Read with rasterio for geospatial images + try: + with rasterio.open(source) as src: + # Read all bands + image = src.read() # Shape: [C, H, W] + image = np.transpose(image, (1, 2, 0)) # Convert to [H, W, C] + + # Store raster metadata + self.raster_profile = src.profile + + # Detect sensor type + if sensor_type is None: + sensor_type = self._detect_sensor_type(src, source) + + # Get image center coordinates + lat, lon = self._get_raster_center_latlon(src) + + except Exception: + # Fallback to OpenCV for regular images + image = cv2.imread(source) + if image is None: + raise ValueError(f"Could not read image: {source}") + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Use defaults for non-geospatial images + sensor_type = sensor_type or 'naip' + lat, lon = 0.0, 0.0 # Default coordinates + self.raster_profile = None + + elif isinstance(source, np.ndarray): + image = source + sensor_type = sensor_type or 'naip' + lat, lon = 0.0, 0.0 + self.raster_profile = None + + else: + raise ValueError("Source must be a file path or numpy array") + + # Parse date if string + if isinstance(date, str): + try: + date = datetime.datetime.fromisoformat(date.replace('Z', '+00:00')) + except ValueError: + date = datetime.datetime.now() + warnings.warn(f"Could not parse date: {date}. Using current time.") + elif date is None: + date = datetime.datetime.now() + + # Store image and metadata + self.source = source if isinstance(source, str) else None + self.image = image + self.sensor_type = sensor_type + self.lat = lat + self.lon = lon + self.date = date + self.gsd_override = gsd_override + + print(f"Set image: shape={image.shape}, sensor={sensor_type}, " + f"lat={lat:.4f}, lon={lon:.4f}") + + def generate_embeddings( + self, + tile_size: int = 256, + overlap: float = 0.0 + ) -> Dict[str, Any]: + """ + Generate embeddings for the loaded image. + + Args: + tile_size: Size of tiles for processing large images + overlap: Overlap fraction between tiles (0.0 to 1.0) + + Returns: + Dictionary containing embeddings and metadata + """ + if self.image is None: + raise ValueError("No image loaded. 
Call set_image() first.") + + image = self.image + h, w = image.shape[:2] + + # If image is smaller than tile_size, process as single tile + if h <= tile_size and w <= tile_size: + # Pad image to tile_size if needed + if h < tile_size or w < tile_size: + pad_h = max(0, tile_size - h) + pad_w = max(0, tile_size - w) + image = np.pad( + image, + ((0, pad_h), (0, pad_w), (0, 0)), + mode='reflect' + ) + + # Generate single embedding + datacube = self._prepare_datacube( + image, self.sensor_type, self.lat, self.lon, + self.date, self.gsd_override + ) + + with torch.no_grad(): + encoded_patches, _, _, _ = self.model.model.encoder(datacube) + # Extract class token (global embedding) + embedding = encoded_patches[:, 0, :].cpu().numpy() + + return { + 'embeddings': embedding, + 'tile_coords': [(0, 0, h, w)], + 'image_shape': (h, w), + 'sensor_type': self.sensor_type, + 'lat': self.lat, + 'lon': self.lon, + 'date': self.date.isoformat() if self.date else None, + 'num_tiles': 1 + } + + else: + # Process as overlapping tiles + step_size = int(tile_size * (1 - overlap)) + embeddings = [] + tile_coords = [] + + for y in range(0, h - tile_size + 1, step_size): + for x in range(0, w - tile_size + 1, step_size): + # Extract tile + tile = image[y:y+tile_size, x:x+tile_size] + + # Prepare datacube for this tile + datacube = self._prepare_datacube( + tile, self.sensor_type, self.lat, self.lon, + self.date, self.gsd_override + ) + + # Generate embedding + with torch.no_grad(): + encoded_patches, _, _, _ = self.model.model.encoder(datacube) + embedding = encoded_patches[:, 0, :].cpu().numpy() + + embeddings.append(embedding) + tile_coords.append((x, y, x+tile_size, y+tile_size)) + + return { + 'embeddings': np.vstack(embeddings), + 'tile_coords': tile_coords, + 'image_shape': (h, w), + 'sensor_type': self.sensor_type, + 'lat': self.lat, + 'lon': self.lon, + 'date': self.date.isoformat() if self.date else None, + 'num_tiles': len(embeddings) + } + + def save_embeddings( + self, + embeddings_result: Dict[str, Any], + output_path: str, + format: str = 'npz' + ): + """ + Save embeddings to file. + + Args: + embeddings_result: Result from generate_embeddings() + output_path: Output file path + format: Output format ('npz', 'pt') + """ + output_path = check_file_path(output_path) + + if format == 'npz': + np.savez_compressed( + output_path, + embeddings=embeddings_result['embeddings'], + tile_coords=np.array(embeddings_result['tile_coords']), + image_shape=np.array(embeddings_result['image_shape']), + sensor_type=embeddings_result['sensor_type'], + lat=embeddings_result['lat'], + lon=embeddings_result['lon'], + date=embeddings_result['date'], + num_tiles=embeddings_result['num_tiles'] + ) + elif format == 'pt': + torch.save(embeddings_result, output_path) + else: + raise ValueError(f"Unsupported format: {format}") + + print(f"Saved embeddings to {output_path}") + + +def load_embeddings(file_path: str) -> Dict[str, Any]: + """ + Load embeddings from file. 
+ + Args: + file_path: Path to embeddings file + + Returns: + Embeddings dictionary + """ + if file_path.endswith('.npz'): + data = np.load(file_path, allow_pickle=True) + return { + 'embeddings': data['embeddings'], + 'tile_coords': data['tile_coords'].tolist(), + 'image_shape': tuple(data['image_shape']), + 'sensor_type': str(data['sensor_type']), + 'lat': float(data['lat']), + 'lon': float(data['lon']), + 'date': str(data['date']) if data['date'] != 'None' else None, + 'num_tiles': int(data['num_tiles']) + } + elif file_path.endswith('.pt'): + return torch.load(file_path, map_location='cpu') + else: + raise ValueError(f"Unsupported file format: {file_path}") \ No newline at end of file diff --git a/samgeo/clay_metadata.yaml b/samgeo/clay_metadata.yaml new file mode 100644 index 00000000..d18ebbae --- /dev/null +++ b/samgeo/clay_metadata.yaml @@ -0,0 +1,295 @@ +sentinel-2-l2a: + band_order: + - blue + - green + - red + - rededge1 + - rededge2 + - rededge3 + - nir + - nir08 + - swir16 + - swir22 + rgb_indices: + - 2 + - 1 + - 0 + gsd: 10 + bands: + mean: + blue: 1105. + green: 1355. + red: 1552. + rededge1: 1887. + rededge2: 2422. + rededge3: 2630. + nir: 2743. + nir08: 2785. + swir16: 2388. + swir22: 1835. + std: + blue: 1809. + green: 1757. + red: 1888. + rededge1: 1870. + rededge2: 1732. + rededge3: 1697. + nir: 1742. + nir08: 1648. + swir16: 1470. + swir22: 1379. + wavelength: + blue: 0.493 + green: 0.56 + red: 0.665 + rededge1: 0.704 + rededge2: 0.74 + rededge3: 0.783 + nir: 0.842 + nir08: 0.865 + swir16: 1.61 + swir22: 2.19 +planetscope-sr: + band_order: + - coastal_blue + - blue + - green_i + - green + - yellow + - red + - rededge + - nir + rgb_indices: + - 5 + - 3 + - 1 + gsd: 5 + bands: + mean: + coastal_blue: 1720. + blue: 1715. + green_i: 1913. + green: 2088. + yellow: 2274. + red: 2290. + rededge: 2613. + nir: 3970. + std: + coastal_blue: 747. + blue: 698. + green_i: 739. + green: 768. + yellow: 849. + red: 868. + rededge: 849. + nir: 914. + wavelength: + coastal_blue: 0.443 + blue: 0.490 + green_i: 0.531 + green: 0.565 + yellow: 0.610 + red: 0.665 + rededge: 0.705 + nir: 0.865 +landsat-c2l1: + band_order: + - red + - green + - blue + - nir08 + - swir16 + - swir22 + rgb_indices: + - 0 + - 1 + - 2 + gsd: 30 + bands: + mean: + red: 10678. + green: 10563. + blue: 11083. + nir08: 14792. + swir16: 12276. + swir22: 10114. + std: + red: 6025. + green: 5411. + blue: 5468. + nir08: 6746. + swir16: 5897. + swir22: 4850. + wavelength: + red: 0.65 + green: 0.56 + blue: 0.48 + nir08: 0.86 + swir16: 1.6 + swir22: 2.2 +landsat-c2l2-sr: + band_order: + - red + - green + - blue + - nir08 + - swir16 + - swir22 + rgb_indices: + - 0 + - 1 + - 2 + gsd: 30 + bands: + mean: + red: 13705. + green: 13310. + blue: 12474. + nir08: 17801. + swir16: 14615. + swir22: 12701. + std: + red: 9578. + green: 9408. + blue: 10144. + nir08: 8277. + swir16: 5300. + swir22: 4522. 
+ wavelength: + red: 0.65 + green: 0.56 + blue: 0.48 + nir08: 0.86 + swir16: 1.6 + swir22: 2.2 +naip: + band_order: + - red + - green + - blue + - nir + rgb_indices: + - 0 + - 1 + - 2 + gsd: 1.0 + bands: + mean: + red: 110.16 + green: 115.41 + blue: 98.15 + nir: 139.04 + std: + red: 47.23 + green: 39.82 + blue: 35.43 + nir: 49.86 + wavelength: + red: 0.65 + green: 0.56 + blue: 0.48 + nir: 0.842 +linz: + band_order: + - red + - green + - blue + rgb_indices: + - 0 + - 1 + - 2 + gsd: 0.5 + bands: + mean: + red: 89.96 + green: 99.46 + blue: 89.51 + std: + red: 41.83 + green: 36.96 + blue: 31.45 + wavelength: + red: 0.635 + green: 0.555 + blue: 0.465 +sentinel-1-rtc: + band_order: + - vv + - vh + gsd: 10 + bands: + mean: + vv: -12.113 + vh: -18.673 + std: + vv: 8.314 + vh: 8.017 + wavelength: + vv: 3.5 + vh: 4.0 +modis: + band_order: + - sur_refl_b01 + - sur_refl_b02 + - sur_refl_b03 + - sur_refl_b04 + - sur_refl_b05 + - sur_refl_b06 + - sur_refl_b07 + rgb_indices: + - 0 + - 3 + - 2 + gsd: 500 + bands: + mean: + sur_refl_b01: 1072. + sur_refl_b02: 1624. + sur_refl_b03: 931. + sur_refl_b04: 1023. + sur_refl_b05: 1599. + sur_refl_b06: 1404. + sur_refl_b07: 1051. + std: + sur_refl_b01: 1643. + sur_refl_b02: 1878. + sur_refl_b03: 1449. + sur_refl_b04: 1538. + sur_refl_b05: 1763. + sur_refl_b06: 1618. + sur_refl_b07: 1396. + wavelength: + sur_refl_b01: .645 + sur_refl_b02: .858 + sur_refl_b03: .469 + sur_refl_b04: .555 + sur_refl_b05: 1.240 + sur_refl_b06: 1.640 + sur_refl_b07: 2.130 +satellogic-MSI-L1D: + band_order: + - red + - green + - blue + - nir + rgb_indices: + - 0 + - 1 + - 2 + gsd: 1.0 + bands: + mean: + red: 1451.54 + green: 1456.54 + blue: 1543.22 + nir: 2132.68 + std: + red: 995.48 + green: 771.29 + blue: 708.86 + nir: 1236.71 + wavelength: + red: 0.640 + green: 0.545 + blue: 0.480 + nir: 0.825 From c5cb11ac1908b6332aec376b414cd6704d59035d Mon Sep 17 00:00:00 2001 From: Pratik-Doshi-99 Date: Wed, 23 Jul 2025 00:34:12 +0000 Subject: [PATCH 2/3] added model size and fixed time and space normalization --- samgeo/clay.py | 62 ++++++++++++++++++++++++++------------------------ 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/samgeo/clay.py b/samgeo/clay.py index e396a5d5..7bc61982 100644 --- a/samgeo/clay.py +++ b/samgeo/clay.py @@ -70,32 +70,24 @@ } + def normalize_timestamp(date): - """Normalize timestamp to week and hour components for Clay model.""" - if isinstance(date, str): - date = datetime.datetime.fromisoformat(date.replace('Z', '+00:00')) - elif not isinstance(date, datetime.datetime): - date = datetime.datetime.now() - - # Get day of year and hour - day_of_year = date.timetuple().tm_yday - hour = date.hour - - # Normalize to [-1, 1] range - week_norm = 2 * (day_of_year - 1) / 365 - 1 - hour_norm = 2 * hour / 24 - 1 - - return [week_norm, hour_norm] + """Normaize the timestamp for clay. Taken from https://github.com/Clay-foundation/stacchip/blob/main/stacchip/processors/prechip.py""" + week = date.isocalendar().week * 2 * np.pi / 52 + hour = date.hour * 2 * np.pi / 24 + return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour)) -def normalize_latlon(lat: float, lon: float) -> Tuple[List[float], List[float]]: - lat_rad = lat * np.pi / 180 - lon_rad = lon * np.pi / 180 - - lat_norm = [math.sin(lat_rad), math.cos(lat_rad)] - lon_norm = [math.sin(lon_rad), math.cos(lon_rad)] - - return lat_norm, lon_norm + +def normalize_latlon(bounds): + """Normalize latitude and longitude for clay. 
Taken from https://github.com/Clay-foundation/stacchip/blob/main/stacchip/processors/prechip.py""" + lon = bounds[0] + (bounds[2] - bounds[0]) / 2 + lat = bounds[1] + (bounds[3] - bounds[1]) / 2 + + lat = lat * np.pi / 180 + lon = lon * np.pi / 180 + + return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon)) class Clay: @@ -111,6 +103,7 @@ class Clay: def __init__( self, checkpoint_path: str, + model_size: str = 'large', metadata_path: Optional[str] = None, device: str = "auto", ): @@ -149,6 +142,13 @@ def __init__( if metadata_path: warnings.warn(f"Metadata file not found: {metadata_path}. Using defaults.") + + self.model_size = model_size + if self.model_size not in ['tiny','small','base','large']: + raise ValueError(f"model_size must be one of: {['tiny','small','base','large']}") + + + # Load model self._load_model() @@ -162,13 +162,15 @@ def _load_model(self): """Load the Clay model from checkpoint.""" try: torch.set_default_device(self.device) - self.model = ClayMAEModule.load_from_checkpoint( - self.checkpoint_path, + self.module = ClayMAEModule.load_from_checkpoint( + checkpoint_path=self.checkpoint_path, + model_size=self.model_size, + dolls=[16, 32, 64, 128, 256, 768, 1024], + doll_weights=[1, 1, 1, 1, 1, 1, 1], + mask_ratio=0.0, shuffle=False, - mask_ratio=0.0 ) - self.model.eval() - self.model = self.model.to(self.device) + self.module.eval() except Exception as e: raise RuntimeError(f"Failed to load Clay model: {e}") @@ -440,7 +442,7 @@ def generate_embeddings( ) with torch.no_grad(): - encoded_patches, _, _, _ = self.model.model.encoder(datacube) + encoded_patches, _, _, _ = self.module.model.encoder(datacube) # Extract class token (global embedding) embedding = encoded_patches[:, 0, :].cpu().numpy() @@ -474,7 +476,7 @@ def generate_embeddings( # Generate embedding with torch.no_grad(): - encoded_patches, _, _, _ = self.model.model.encoder(datacube) + encoded_patches, _, _, _ = self.module.model.encoder(datacube) embedding = encoded_patches[:, 0, :].cpu().numpy() embeddings.append(embedding) From a9e56ac99903db98499fd5a5d4391cb906605006 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Jul 2025 17:38:11 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/clay_embeddings_example.py | 122 ++++---- samgeo/clay.py | 467 ++++++++++++++++------------ 2 files changed, 330 insertions(+), 259 deletions(-) diff --git a/examples/clay_embeddings_example.py b/examples/clay_embeddings_example.py index 920deac3..1a3c6bc5 100644 --- a/examples/clay_embeddings_example.py +++ b/examples/clay_embeddings_example.py @@ -22,14 +22,14 @@ def main(): # Configuration CHECKPOINT_PATH = "path/to/clay-model-checkpoint.ckpt" # Update this path - IMAGE_PATH = "path/to/your/satellite_image.tif" # Update this path + IMAGE_PATH = "path/to/your/satellite_image.tif" # Update this path OUTPUT_DIR = "clay_embeddings_output" - + # Create output directory os.makedirs(OUTPUT_DIR, exist_ok=True) - + print("=== Clay Foundation Model Embeddings Example ===\n") - + # Step 1: Initialize Clay embeddings model print("1. 
Initializing Clay model...")
     try:
@@ -37,16 +37,18 @@ def main():
             checkpoint_path=CHECKPOINT_PATH,
             device="auto",  # Will use GPU if available
-            mask_ratio=0.0,  # No masking for inference
-            shuffle=False
         )
         print("   ✓ Clay model loaded successfully")
     except Exception as e:
         print(f"   ✗ Error loading Clay model: {e}")
         print("   Please ensure you have:")
         print("   - Valid Clay checkpoint file")
-        print("   - Clay dependencies: pip install claymodel torch torchvision pyyaml python-box")
+        print(
+            "   - Clay dependencies: pip install claymodel torch torchvision pyyaml python-box"
+        )
         return
-
+
     # Step 2: Load and analyze image
     print("\n2. Loading geospatial image...")
     try:
         clay.set_image(
@@ -64,35 +66,35 @@ def main():
         print(f"   ✗ Error loading image: {e}")
         print("   Please check the image path and format")
         return
-
+
     # Step 3: Generate embeddings
     print("\n3. Generating Clay embeddings...")
     try:
         # For large images, process in tiles
         embeddings_result = clay.generate_embeddings(
-            tile_size=256,  # Size of processing tiles
-            overlap=0.1  # 10% overlap between tiles
+            tile_size=256,  # Size of processing tiles
+            overlap=0.1,  # 10% overlap between tiles
         )
-
+
         print("   ✓ Embeddings generated successfully")
         print(f"   - Number of tiles: {embeddings_result['num_tiles']}")
         print(f"   - Embedding shape: {embeddings_result['embeddings'].shape}")
         print(f"   - Feature dimension: {embeddings_result['embeddings'].shape[-1]}")
-
+
     except Exception as e:
         print(f"   ✗ Error generating embeddings: {e}")
         return
-
+
     # Step 4: Save embeddings
     print("\n4. Saving embeddings...")
     try:
         embeddings_file = os.path.join(OUTPUT_DIR, "clay_embeddings.npz")
-        clay.save_embeddings(embeddings_result, embeddings_file, format='npz')
+        clay.save_embeddings(embeddings_result, embeddings_file, format="npz")
         print(f"   ✓ Embeddings saved to {embeddings_file}")
     except Exception as e:
         print(f"   ✗ Error saving embeddings: {e}")
         return
-
+
     # Step 5: Load and verify embeddings
     print("\n5. Loading and verifying saved embeddings...")
     try:
@@ -104,81 +106,88 @@ def main():
         print(f"   ✗ Error loading embeddings: {e}")
         return
-
+
     # Step 6: Visualize results
     print("\n6. 
Creating visualizations...") try: # Plot RGB image if available fig, axes = plt.subplots(1, 2, figsize=(15, 6)) - + # Original image (RGB bands if available) image = clay.image if clay.sensor_type in clay.metadata: - rgb_indices = clay.metadata[clay.sensor_type].get('rgb_indices', [0, 1, 2]) + rgb_indices = clay.metadata[clay.sensor_type].get("rgb_indices", [0, 1, 2]) if len(rgb_indices) == 3 and image.shape[2] >= max(rgb_indices) + 1: rgb_image = image[:, :, rgb_indices] # Normalize for display rgb_image = np.clip(rgb_image / np.percentile(rgb_image, 98), 0, 1) axes[0].imshow(rgb_image) - axes[0].set_title(f'Original Image ({clay.sensor_type})') - axes[0].axis('off') + axes[0].set_title(f"Original Image ({clay.sensor_type})") + axes[0].axis("off") else: - axes[0].imshow(image[:, :, 0], cmap='gray') - axes[0].set_title('Original Image (First Band)') - axes[0].axis('off') + axes[0].imshow(image[:, :, 0], cmap="gray") + axes[0].set_title("Original Image (First Band)") + axes[0].axis("off") else: - axes[0].imshow(image[:, :, 0], cmap='gray') - axes[0].set_title('Original Image (First Band)') - axes[0].axis('off') - + axes[0].imshow(image[:, :, 0], cmap="gray") + axes[0].set_title("Original Image (First Band)") + axes[0].axis("off") + # Embedding visualization (PCA of first tile) - embeddings = embeddings_result['embeddings'] + embeddings = embeddings_result["embeddings"] if embeddings.shape[0] > 0: # Use first embedding for visualization first_embedding = embeddings[0].flatten() - + # Create a simple visualization of embedding values - embedding_2d = first_embedding[:256].reshape(16, 16) # Take first 256 values - axes[1].imshow(embedding_2d, cmap='viridis') - axes[1].set_title('Clay Embedding Visualization\n(First 256 features, first tile)') - axes[1].axis('off') - + embedding_2d = first_embedding[:256].reshape( + 16, 16 + ) # Take first 256 values + axes[1].imshow(embedding_2d, cmap="viridis") + axes[1].set_title( + "Clay Embedding Visualization\n(First 256 features, first tile)" + ) + axes[1].axis("off") + plt.tight_layout() - + # Save plot plot_file = os.path.join(OUTPUT_DIR, "clay_embeddings_visualization.png") - plt.savefig(plot_file, dpi=150, bbox_inches='tight') + plt.savefig(plot_file, dpi=150, bbox_inches="tight") plt.show() - + print(f" ✓ Visualization saved to {plot_file}") - + except Exception as e: print(f" ✗ Error creating visualizations: {e}") - + # Step 7: Demonstrate embedding analysis print("\n7. Embedding analysis...") try: - embeddings = embeddings_result['embeddings'] - + embeddings = embeddings_result["embeddings"] + # Basic statistics print(f" - Embedding statistics:") print(f" * Mean: {np.mean(embeddings):.4f}") print(f" * Std: {np.std(embeddings):.4f}") print(f" * Min: {np.min(embeddings):.4f}") print(f" * Max: {np.max(embeddings):.4f}") - + # Similarity between tiles (if multiple tiles) if embeddings.shape[0] > 1: from sklearn.metrics.pairwise import cosine_similarity + similarities = cosine_similarity(embeddings) - avg_similarity = np.mean(similarities[np.triu_indices_from(similarities, k=1)]) + avg_similarity = np.mean( + similarities[np.triu_indices_from(similarities, k=1)] + ) print(f" * Average tile similarity: {avg_similarity:.4f}") - + print(" ✓ Analysis complete") - + except Exception as e: print(f" ✗ Error in embedding analysis: {e}") - + print(f"\n=== Example completed successfully! 
===") print(f"Output files saved in: {OUTPUT_DIR}/") print("\nNext steps:") @@ -190,37 +199,36 @@ def main(): def example_with_numpy_array(): """Example showing how to use Clay embeddings with numpy arrays.""" print("\n=== Numpy Array Example ===") - + # Create a synthetic 4-band image (RGBI) synthetic_image = np.random.randint(0, 255, (256, 256, 4), dtype=np.uint8) - + try: # Initialize Clay model clay = ClayEmbeddings( - checkpoint_path="path/to/clay-model-checkpoint.ckpt", - device="auto" + checkpoint_path="path/to/clay-model-checkpoint.ckpt", device="auto" ) - + # Set synthetic image clay.set_image( source=synthetic_image, sensor_type="naip", # Specify sensor type for numpy arrays - date="2023-06-01" + date="2023-06-01", ) - + # Generate embeddings result = clay.generate_embeddings(tile_size=256) - + print(f"Generated embeddings for synthetic image:") print(f"- Shape: {result['embeddings'].shape}") print(f"- Sensor: {result['sensor_type']}") - + except Exception as e: print(f"Error in numpy array example: {e}") if __name__ == "__main__": main() - + # Uncomment to run numpy array example - # example_with_numpy_array() \ No newline at end of file + # example_with_numpy_array() diff --git a/samgeo/clay.py b/samgeo/clay.py index 7bc61982..c07ca4ef 100644 --- a/samgeo/clay.py +++ b/samgeo/clay.py @@ -2,7 +2,7 @@ Clay foundation model wrapper for geospatial embeddings. This module provides a wrapper around the Clay foundation model for generating -rich spectral embeddings from geospatial imagery. It integrates with the +rich spectral embeddings from geospatial imagery. It integrates with the segment-geospatial library's raster I/O infrastructure. """ @@ -23,6 +23,7 @@ from torchvision.transforms import v2 import yaml from box import Box + CLAY_AVAILABLE = True except ImportError: CLAY_AVAILABLE = False @@ -37,40 +38,104 @@ # Default metadata for common sensors DEFAULT_METADATA = { - 'sentinel-2-l2a': { - 'band_order': ['blue', 'green', 'red', 'rededge1', 'rededge2', 'rededge3', 'nir', 'nir08', 'swir16', 'swir22'], - 'rgb_indices': [2, 1, 0], - 'gsd': 10, - 'bands': { - 'mean': {'blue': 1105., 'green': 1355., 'red': 1552., 'rededge1': 1887., 'rededge2': 2422., 'rededge3': 2630., 'nir': 2743., 'nir08': 2785., 'swir16': 2388., 'swir22': 1835.}, - 'std': {'blue': 1809., 'green': 1757., 'red': 1888., 'rededge1': 1870., 'rededge2': 1732., 'rededge3': 1697., 'nir': 1742., 'nir08': 1648., 'swir16': 1470., 'swir22': 1379.}, - 'wavelength': {'blue': 0.493, 'green': 0.56, 'red': 0.665, 'rededge1': 0.704, 'rededge2': 0.74, 'rededge3': 0.783, 'nir': 0.842, 'nir08': 0.865, 'swir16': 1.61, 'swir22': 2.19} - } + "sentinel-2-l2a": { + "band_order": [ + "blue", + "green", + "red", + "rededge1", + "rededge2", + "rededge3", + "nir", + "nir08", + "swir16", + "swir22", + ], + "rgb_indices": [2, 1, 0], + "gsd": 10, + "bands": { + "mean": { + "blue": 1105.0, + "green": 1355.0, + "red": 1552.0, + "rededge1": 1887.0, + "rededge2": 2422.0, + "rededge3": 2630.0, + "nir": 2743.0, + "nir08": 2785.0, + "swir16": 2388.0, + "swir22": 1835.0, + }, + "std": { + "blue": 1809.0, + "green": 1757.0, + "red": 1888.0, + "rededge1": 1870.0, + "rededge2": 1732.0, + "rededge3": 1697.0, + "nir": 1742.0, + "nir08": 1648.0, + "swir16": 1470.0, + "swir22": 1379.0, + }, + "wavelength": { + "blue": 0.493, + "green": 0.56, + "red": 0.665, + "rededge1": 0.704, + "rededge2": 0.74, + "rededge3": 0.783, + "nir": 0.842, + "nir08": 0.865, + "swir16": 1.61, + "swir22": 2.19, + }, + }, }, - 'landsat-c2l2-sr': { - 'band_order': ['red', 'green', 
'blue', 'nir08', 'swir16', 'swir22'],
-        'rgb_indices': [0, 1, 2],
-        'gsd': 30,
-        'bands': {
-            'mean': {'red': 13705., 'green': 13310., 'blue': 12474., 'nir08': 17801., 'swir16': 14615., 'swir22': 12701.},
-            'std': {'red': 9578., 'green': 9408., 'blue': 10144., 'nir08': 8277., 'swir16': 5300., 'swir22': 4522.},
-            'wavelength': {'red': 0.65, 'green': 0.56, 'blue': 0.48, 'nir08': 0.86, 'swir16': 1.6, 'swir22': 2.2}
-        }
+    "landsat-c2l2-sr": {
+        "band_order": ["red", "green", "blue", "nir08", "swir16", "swir22"],
+        "rgb_indices": [0, 1, 2],
+        "gsd": 30,
+        "bands": {
+            "mean": {
+                "red": 13705.0,
+                "green": 13310.0,
+                "blue": 12474.0,
+                "nir08": 17801.0,
+                "swir16": 14615.0,
+                "swir22": 12701.0,
+            },
+            "std": {
+                "red": 9578.0,
+                "green": 9408.0,
+                "blue": 10144.0,
+                "nir08": 8277.0,
+                "swir16": 5300.0,
+                "swir22": 4522.0,
+            },
+            "wavelength": {
+                "red": 0.65,
+                "green": 0.56,
+                "blue": 0.48,
+                "nir08": 0.86,
+                "swir16": 1.6,
+                "swir22": 2.2,
+            },
+        },
+    },
+    "naip": {
+        "band_order": ["red", "green", "blue", "nir"],
+        "rgb_indices": [0, 1, 2],
+        "gsd": 1.0,
+        "bands": {
+            "mean": {"red": 110.16, "green": 115.41, "blue": 98.15, "nir": 139.04},
+            "std": {"red": 47.23, "green": 39.82, "blue": 35.43, "nir": 49.86},
+            "wavelength": {"red": 0.65, "green": 0.56, "blue": 0.48, "nir": 0.842},
+        },
     },
-    'naip': {
-        'band_order': ['red', 'green', 'blue', 'nir'],
-        'rgb_indices': [0, 1, 2],
-        'gsd': 1.0,
-        'bands': {
-            'mean': {'red': 110.16, 'green': 115.41, 'blue': 98.15, 'nir': 139.04},
-            'std': {'red': 47.23, 'green': 39.82, 'blue': 35.43, 'nir': 49.86},
-            'wavelength': {'red': 0.65, 'green': 0.56, 'blue': 0.48, 'nir': 0.842}
-        }
-    }
 }
 
-
 def normalize_timestamp(date):
-    """Normaize the timestamp for clay. Taken from https://github.com/Clay-foundation/stacchip/blob/main/stacchip/processors/prechip.py"""
+    """Normalize the timestamp for clay. Taken from https://github.com/Clay-foundation/stacchip/blob/main/stacchip/processors/prechip.py"""
     week = date.isocalendar().week * 2 * np.pi / 52
@@ -93,23 +158,21 @@ def normalize_latlon(bounds):
 class Clay:
     """
     Clay foundation model wrapper for generating geospatial embeddings.
-    
+
     This class provides an interface to generate rich spectral embeddings from
     geospatial imagery using the Clay foundation model.
     """
-    
-    
-    
+
     def __init__(
         self,
         checkpoint_path: str,
-        model_size: str = 'large',
+        model_size: str = "large",
         metadata_path: Optional[str] = None,
         device: str = "auto",
     ):
         """
         Initialize Clay embeddings model.
-        
+
         Args:
             checkpoint_path: Path to Clay model checkpoint
+            model_size: Clay model size: 'tiny', 'small', 'base', or 'large'
             metadata_path: Path to Clay metadata YAML file (optional)
             device: Device to run model on ('auto', 'cuda', 'cpu')
-            mask_ratio: Masking ratio for model (0.0 for inference)
-            shuffle: Whether to shuffle patches
@@ -122,42 +185,43 @@ def __init__(
                 "Clay model dependencies not available. "
                 "Please install: pip install claymodel torch torchvision pyyaml python-box"
             )
-
+
         self.checkpoint_path = check_file_path(checkpoint_path, make_dirs=False)
         if not os.path.exists(self.checkpoint_path):
             raise FileNotFoundError(f"Checkpoint not found: {self.checkpoint_path}")
-
+
         # Set device
         if device == "auto":
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         else:
             self.device = torch.device(device)
-
+
         # Load metadata
         if metadata_path and os.path.exists(metadata_path):
-            with open(metadata_path, 'r') as f:
+            with open(metadata_path, "r") as f:
                 self.metadata = Box(yaml.safe_load(f))
         else:
-            self.metadata = Box(self.DEFAULT_METADATA)
+            self.metadata = Box(DEFAULT_METADATA)
             if metadata_path:
-                warnings.warn(f"Metadata file not found: {metadata_path}. Using defaults.")
-
-
-        self.model_size = model_size
-        if self.model_size not in ['tiny','small','base','large']:
-            raise ValueError(f"model_size must be one of: {['tiny','small','base','large']}")
+                warnings.warn(
+                    f"Metadata file not found: {metadata_path}. 
Using defaults." + ) - + self.model_size = model_size + if self.model_size not in ["tiny", "small", "base", "large"]: + raise ValueError( + f"model_size must be one of: {['tiny','small','base','large']}" + ) # Load model self._load_model() - + # Image processing attributes self.image = None self.source = None self.sensor_type = None self.raster_profile = None - + def _load_model(self): """Load the Clay model from checkpoint.""" try: @@ -173,106 +237,102 @@ def _load_model(self): self.module.eval() except Exception as e: raise RuntimeError(f"Failed to load Clay model: {e}") - + def _detect_sensor_type( - self, - src: rasterio.DatasetReader, - source_path: Optional[str] = None + self, src: rasterio.DatasetReader, source_path: Optional[str] = None ) -> str: """ Detect sensor type from raster metadata and characteristics. - + Args: src: Rasterio dataset reader source_path: Optional source file path for filename-based detection - + Returns: Detected sensor type string """ band_count = src.count resolution = abs(src.transform[0]) # Pixel size - + # Try filename-based detection first if source_path: filename = os.path.basename(source_path).lower() - if 'sentinel' in filename or 's2' in filename: - return 'sentinel-2-l2a' - elif 'landsat' in filename or 'l8' in filename or 'l9' in filename: - return 'landsat-c2l2-sr' - elif 'naip' in filename: - return 'naip' - + if "sentinel" in filename or "s2" in filename: + return "sentinel-2-l2a" + elif "landsat" in filename or "l8" in filename or "l9" in filename: + return "landsat-c2l2-sr" + elif "naip" in filename: + return "naip" + # Fallback to resolution and band count heuristics if band_count == 4 and resolution <= 5: - return 'naip' # High-res 4-band imagery + return "naip" # High-res 4-band imagery elif band_count >= 6 and 25 <= resolution <= 35: - return 'landsat-c2l2-sr' # Landsat resolution + return "landsat-c2l2-sr" # Landsat resolution elif band_count >= 10 and 8 <= resolution <= 12: - return 'sentinel-2-l2a' # Sentinel-2 resolution + return "sentinel-2-l2a" # Sentinel-2 resolution elif band_count == 4: - return 'naip' # Default 4-band to NAIP + return "naip" # Default 4-band to NAIP else: # Default fallback warnings.warn( f"Could not detect sensor type (bands: {band_count}, " f"resolution: {resolution:.1f}m). Defaulting to NAIP." ) - return 'naip' - - def _get_raster_center_latlon(self, src: rasterio.DatasetReader) -> Tuple[float, float]: + return "naip" + + def _get_raster_center_latlon( + self, src: rasterio.DatasetReader + ) -> Tuple[float, float]: """Get the center lat/lon of the raster.""" bounds = src.bounds center_x = (bounds.left + bounds.right) / 2 center_y = (bounds.bottom + bounds.top) / 2 - + # Transform to WGS84 if needed - if src.crs != 'EPSG:4326': - lon, lat = transform_coords( - [(center_x, center_y)], - src.crs, - 'EPSG:4326' - )[0] + if src.crs != "EPSG:4326": + lon, lat = transform_coords([(center_x, center_y)], src.crs, "EPSG:4326")[0] else: lon, lat = center_x, center_y - + return lat, lon - + def _prepare_datacube( - self, - image: np.ndarray, + self, + image: np.ndarray, sensor_type: str, - lat: float, - lon: float, + lat: float, + lon: float, date: Optional[datetime.datetime] = None, - gsd_override: Optional[float] = None + gsd_override: Optional[float] = None, ) -> Dict[str, torch.Tensor]: """ Prepare datacube for Clay model input. 
-
+
         Args:
             image: Input image array [H, W, C]
             sensor_type: Detected sensor type
             lat: Latitude of image center
-            lon: Longitude of image center 
+            lon: Longitude of image center
             date: Image acquisition date
             gsd_override: Override GSD value
-
+
         Returns:
             Datacube dictionary for Clay model
         """
         if date is None:
             date = datetime.datetime.now()
-
+
         # Get sensor metadata
         sensor_meta = self.metadata[sensor_type]
         band_order = sensor_meta.band_order
         gsd = gsd_override if gsd_override is not None else sensor_meta.gsd
-
+
         # Extract normalization parameters
         means = [sensor_meta.bands.mean[band] for band in band_order]
         stds = [sensor_meta.bands.std[band] for band in band_order]
         wavelengths = [sensor_meta.bands.wavelength[band] for band in band_order]
-
+
         # Convert image to torch tensor and normalize
         # Ensure we have the right number of bands
         if image.shape[2] != len(band_order):
@@ -286,49 +346,51 @@ def _prepare_datacube(
             means = means[:num_bands]
             stds = stds[:num_bands]
             wavelengths = wavelengths[:num_bands]
-
+
         # Convert to tensor and transpose to [C, H, W]
         pixels = torch.from_numpy(image.astype(np.float32)).permute(2, 0, 1)
-
+
         # Normalize
         transform = v2.Compose([v2.Normalize(mean=means, std=stds)])
         pixels = transform(pixels).unsqueeze(0)  # Add batch dimension
-
+
         # Prepare temporal encoding
-        time_norm = normalize_timestamp(date)
-
+        week_norm, hour_norm = normalize_timestamp(date)
+
         # Prepare spatial encoding
-        lat_norm, lon_norm = normalize_latlon(lat, lon)
-
+        # normalize_latlon() expects raster bounds; a degenerate bounds tuple
+        # centered on (lon, lat) yields the same point
+        lat_norm, lon_norm = normalize_latlon((lon, lat, lon, lat))
+
         # Create datacube
         datacube = {
-            'pixels': pixels.to(self.device),
-            'time': torch.tensor(
-                time_norm + time_norm,  # Clay expects 4 elements: [week, hour, week, hour]
+            "pixels": pixels.to(self.device),
+            "time": torch.tensor(
+                week_norm + hour_norm,  # Clay expects 4 elements: [week_sin, week_cos, hour_sin, hour_cos]
                 dtype=torch.float32,
-                device=self.device
+                device=self.device,
             ).unsqueeze(0),
-            'latlon': torch.tensor(
-                lat_norm + lon_norm,  # Clay expects 4 elements: [sin_lat, cos_lat, sin_lon, cos_lon]
+            "latlon": torch.tensor(
+                lat_norm + lon_norm,  # Clay expects 4 elements: [sin_lat, cos_lat, sin_lon, cos_lon]
                 dtype=torch.float32,
-                device=self.device
+                device=self.device,
             ).unsqueeze(0),
-            'gsd': torch.tensor(gsd, device=self.device),
-            'waves': torch.tensor(wavelengths, device=self.device)
+            "gsd": torch.tensor(gsd, device=self.device),
+            "waves": torch.tensor(wavelengths, device=self.device),
         }
-
+
         return datacube
-
+
     def set_image(
-        self, 
+        self,
         source: Union[str, np.ndarray],
         sensor_type: Optional[str] = None,
        date: Optional[Union[str, datetime.datetime]] = None,
-        gsd_override: Optional[float] = None
+        gsd_override: Optional[float] = None,
     ):
         """
         Set the input image for embedding generation. 
- + Args: source: Path to image file or numpy array sensor_type: Optional sensor type override @@ -338,58 +400,58 @@ def set_image( if isinstance(source, str): if source.startswith("http"): source = download_file(source) - + if not os.path.exists(source): raise ValueError(f"Input path {source} does not exist.") - + # Read with rasterio for geospatial images try: with rasterio.open(source) as src: # Read all bands image = src.read() # Shape: [C, H, W] image = np.transpose(image, (1, 2, 0)) # Convert to [H, W, C] - + # Store raster metadata self.raster_profile = src.profile - + # Detect sensor type if sensor_type is None: sensor_type = self._detect_sensor_type(src, source) - + # Get image center coordinates lat, lon = self._get_raster_center_latlon(src) - + except Exception: # Fallback to OpenCV for regular images image = cv2.imread(source) if image is None: raise ValueError(f"Could not read image: {source}") image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - + # Use defaults for non-geospatial images - sensor_type = sensor_type or 'naip' + sensor_type = sensor_type or "naip" lat, lon = 0.0, 0.0 # Default coordinates self.raster_profile = None - + elif isinstance(source, np.ndarray): image = source - sensor_type = sensor_type or 'naip' + sensor_type = sensor_type or "naip" lat, lon = 0.0, 0.0 self.raster_profile = None - + else: raise ValueError("Source must be a file path or numpy array") - + # Parse date if string if isinstance(date, str): try: - date = datetime.datetime.fromisoformat(date.replace('Z', '+00:00')) + date = datetime.datetime.fromisoformat(date.replace("Z", "+00:00")) except ValueError: date = datetime.datetime.now() warnings.warn(f"Could not parse date: {date}. Using current time.") elif date is None: date = datetime.datetime.now() - + # Store image and metadata self.source = source if isinstance(source, str) else None self.image = image @@ -398,160 +460,161 @@ def set_image( self.lon = lon self.date = date self.gsd_override = gsd_override - - print(f"Set image: shape={image.shape}, sensor={sensor_type}, " - f"lat={lat:.4f}, lon={lon:.4f}") - + + print( + f"Set image: shape={image.shape}, sensor={sensor_type}, " + f"lat={lat:.4f}, lon={lon:.4f}" + ) + def generate_embeddings( - self, - tile_size: int = 256, - overlap: float = 0.0 + self, tile_size: int = 256, overlap: float = 0.0 ) -> Dict[str, Any]: """ Generate embeddings for the loaded image. - + Args: tile_size: Size of tiles for processing large images overlap: Overlap fraction between tiles (0.0 to 1.0) - + Returns: Dictionary containing embeddings and metadata """ if self.image is None: raise ValueError("No image loaded. 
Call set_image() first.") - + image = self.image h, w = image.shape[:2] - + # If image is smaller than tile_size, process as single tile if h <= tile_size and w <= tile_size: # Pad image to tile_size if needed if h < tile_size or w < tile_size: pad_h = max(0, tile_size - h) pad_w = max(0, tile_size - w) - image = np.pad( - image, - ((0, pad_h), (0, pad_w), (0, 0)), - mode='reflect' - ) - + image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect") + # Generate single embedding datacube = self._prepare_datacube( - image, self.sensor_type, self.lat, self.lon, - self.date, self.gsd_override + image, + self.sensor_type, + self.lat, + self.lon, + self.date, + self.gsd_override, ) - + with torch.no_grad(): encoded_patches, _, _, _ = self.module.model.encoder(datacube) # Extract class token (global embedding) embedding = encoded_patches[:, 0, :].cpu().numpy() - + return { - 'embeddings': embedding, - 'tile_coords': [(0, 0, h, w)], - 'image_shape': (h, w), - 'sensor_type': self.sensor_type, - 'lat': self.lat, - 'lon': self.lon, - 'date': self.date.isoformat() if self.date else None, - 'num_tiles': 1 + "embeddings": embedding, + "tile_coords": [(0, 0, h, w)], + "image_shape": (h, w), + "sensor_type": self.sensor_type, + "lat": self.lat, + "lon": self.lon, + "date": self.date.isoformat() if self.date else None, + "num_tiles": 1, } - + else: # Process as overlapping tiles step_size = int(tile_size * (1 - overlap)) embeddings = [] tile_coords = [] - + for y in range(0, h - tile_size + 1, step_size): for x in range(0, w - tile_size + 1, step_size): # Extract tile - tile = image[y:y+tile_size, x:x+tile_size] - + tile = image[y : y + tile_size, x : x + tile_size] + # Prepare datacube for this tile datacube = self._prepare_datacube( - tile, self.sensor_type, self.lat, self.lon, - self.date, self.gsd_override + tile, + self.sensor_type, + self.lat, + self.lon, + self.date, + self.gsd_override, ) - + # Generate embedding with torch.no_grad(): encoded_patches, _, _, _ = self.module.model.encoder(datacube) embedding = encoded_patches[:, 0, :].cpu().numpy() - + embeddings.append(embedding) - tile_coords.append((x, y, x+tile_size, y+tile_size)) - + tile_coords.append((x, y, x + tile_size, y + tile_size)) + return { - 'embeddings': np.vstack(embeddings), - 'tile_coords': tile_coords, - 'image_shape': (h, w), - 'sensor_type': self.sensor_type, - 'lat': self.lat, - 'lon': self.lon, - 'date': self.date.isoformat() if self.date else None, - 'num_tiles': len(embeddings) + "embeddings": np.vstack(embeddings), + "tile_coords": tile_coords, + "image_shape": (h, w), + "sensor_type": self.sensor_type, + "lat": self.lat, + "lon": self.lon, + "date": self.date.isoformat() if self.date else None, + "num_tiles": len(embeddings), } - + def save_embeddings( - self, - embeddings_result: Dict[str, Any], - output_path: str, - format: str = 'npz' + self, embeddings_result: Dict[str, Any], output_path: str, format: str = "npz" ): """ Save embeddings to file. 
- + Args: embeddings_result: Result from generate_embeddings() output_path: Output file path format: Output format ('npz', 'pt') """ output_path = check_file_path(output_path) - - if format == 'npz': + + if format == "npz": np.savez_compressed( output_path, - embeddings=embeddings_result['embeddings'], - tile_coords=np.array(embeddings_result['tile_coords']), - image_shape=np.array(embeddings_result['image_shape']), - sensor_type=embeddings_result['sensor_type'], - lat=embeddings_result['lat'], - lon=embeddings_result['lon'], - date=embeddings_result['date'], - num_tiles=embeddings_result['num_tiles'] + embeddings=embeddings_result["embeddings"], + tile_coords=np.array(embeddings_result["tile_coords"]), + image_shape=np.array(embeddings_result["image_shape"]), + sensor_type=embeddings_result["sensor_type"], + lat=embeddings_result["lat"], + lon=embeddings_result["lon"], + date=embeddings_result["date"], + num_tiles=embeddings_result["num_tiles"], ) - elif format == 'pt': + elif format == "pt": torch.save(embeddings_result, output_path) else: raise ValueError(f"Unsupported format: {format}") - + print(f"Saved embeddings to {output_path}") def load_embeddings(file_path: str) -> Dict[str, Any]: """ Load embeddings from file. - + Args: file_path: Path to embeddings file - + Returns: Embeddings dictionary """ - if file_path.endswith('.npz'): + if file_path.endswith(".npz"): data = np.load(file_path, allow_pickle=True) return { - 'embeddings': data['embeddings'], - 'tile_coords': data['tile_coords'].tolist(), - 'image_shape': tuple(data['image_shape']), - 'sensor_type': str(data['sensor_type']), - 'lat': float(data['lat']), - 'lon': float(data['lon']), - 'date': str(data['date']) if data['date'] != 'None' else None, - 'num_tiles': int(data['num_tiles']) + "embeddings": data["embeddings"], + "tile_coords": data["tile_coords"].tolist(), + "image_shape": tuple(data["image_shape"]), + "sensor_type": str(data["sensor_type"]), + "lat": float(data["lat"]), + "lon": float(data["lon"]), + "date": str(data["date"]) if data["date"] != "None" else None, + "num_tiles": int(data["num_tiles"]), } - elif file_path.endswith('.pt'): - return torch.load(file_path, map_location='cpu') + elif file_path.endswith(".pt"): + return torch.load(file_path, map_location="cpu") else: - raise ValueError(f"Unsupported file format: {file_path}") \ No newline at end of file + raise ValueError(f"Unsupported file format: {file_path}")
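
Usage sketch (illustrative, not part of the patch): the example's "Next steps" suggest
similarity search over the saved embeddings. A minimal sketch of what that could look
like on top of load_embeddings() and the tile_coords this patch stores, using the same
scikit-learn cosine_similarity the example's step 7 already imports; the .npz path and
the query tile index are placeholders, and it assumes the image was processed as
multiple tiles:

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity
    from samgeo import load_embeddings

    # Load the .npz file written by Clay.save_embeddings()
    result = load_embeddings("clay_embeddings_output/clay_embeddings.npz")
    emb = result["embeddings"]      # shape: [num_tiles, feature_dim]
    coords = result["tile_coords"]  # (x0, y0, x1, y1) pixel box per tile

    query_idx = 0  # placeholder: tile to compare against
    sims = cosine_similarity(emb[query_idx : query_idx + 1], emb)[0]
    for i in np.argsort(-sims)[:5]:  # top hit is the query tile itself
        print(f"tile {i} at {coords[i]}: similarity {sims[i]:.4f}")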