Releases: Multihuntr/gff

Model weights v1

30 Sep 03:53

This release contains the model weights from our paper. These weights are only guaranteed to work with the code at this tag, so run git checkout v1 before using them.

We used 5 folds, so for each model we provide 5 sets of weights, one per fold. In practice, you may wish to ensemble them.
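
One common way to ensemble the folds is to average the softmax probabilities of the five per-fold models. The sketch below assumes each fold's model has been created and loaded as in the example below and returns logits of shape (B, 2, H, W). `ensemble_predict` is a hypothetical helper, not part of gff, and averaging probabilities is our illustrative choice rather than a scheme prescribed by the paper.

```python
import torch


@torch.no_grad()
def ensemble_predict(models, ex):
    """Average softmax probabilities over per-fold models.

    Hypothetical helper (not part of gff). Each model is assumed to map the
    input dict `ex` to logits of shape (B, 2, H, W).
    """
    # Softmax over the class dimension (dim=1) gives per-pixel probabilities
    probs = [torch.softmax(model(ex), dim=1) for model in models]
    # Mean over the fold dimension: (n_folds, B, 2, H, W) -> (B, 2, H, W)
    return torch.stack(probs).mean(dim=0)
```

Calling `.argmax(dim=1)` on the result gives a (B, H, W) map of class indices; which index denotes water should be checked against the training code.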

You can use these weights from within this repo with:

import torch
import yaml

import gff.models.creation

# Paths to downloaded files
config_path = "path/to/config.yml"
weights_path = "path/to/weights.th"

# Create model
with open(config_path) as f:
    C = yaml.safe_load(f)
model = gff.models.creation.create(C)

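# Load weights into model. map_location="cpu" avoids device errors if the
# checkpoint was saved on a different device; the model is moved to the GPU below.
checkpoint = torch.load(weights_path, map_location="cpu")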
model.load_state_dict(checkpoint["model"])
model.eval()
model.cuda()

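# Dummy forward pass with random inputs. Obtaining real data is left to the user
# and is not easy; the comments below point to the relevant scripts and functions.
B, T = 2, C["weather_window"]  # batch size, number of weather time steps
cH, cW = 32, 32  # spatial size of the context inputs (ERA5, GloFAS, HydroATLAS, context DEM)
fH, fW = 224, 224  # spatial size of the local inputs (Sentinel-1, local DEM, HAND)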
ex = {
    # ERA5/ERA5-land: see scripts/dl-era5-land.py, scripts/export-context.py and gff.data_sources.load_exported_era5_nc
    "era5": torch.randn((B, T, len(C["era5_keys"]), cH, cW)).cuda(),
    "era5_land": torch.randn((B, T, len(C["era5_land_keys"]), cH, cW)).cuda(),
    # GloFAS: see scripts/export-glofas.py, scripts/export-context.py and gff.data_sources.load_glofas
    "glofas": torch.randn((B, T, len(C["glofas_keys"]), cH, cW)).cuda(),
    # HydroATLAS: download rasterised hydroatlas, see gff.data_sources.load_pregenerated_raster
    "hydroatlas_basin": torch.randn((B, len(C["hydroatlas_keys"]), cH, cW)).cuda(),
    # DEM (context): see scripts/export-context.py and gff.data_sources.load_pregenerated_raster
    "dem_context": torch.randn((B, 1, cH, cW)).cuda(),
    # Sentinel-1: see ./preprocessing, gff.data_sources.download_s1, gff.data_sources.export_s1 and scripts/export-local.py
    "s1": torch.randn((B, 2, fH, fW)).cuda(),
    "s1_lead_days": torch.randint(0, 20, (B,)).cuda(),  # computed field
    # DEM (local): see scripts/export-local.py and gff.data_sources.get_dem
    "dem_local": torch.randn((B, 1, fH, fW)).cuda(),
    # HAND: see scripts/export-local.py and gff.data_sources.get_hand
    "hand": torch.randn((B, 1, fH, fW)).cuda(),
}

with torch.no_grad():  # inference only, so skip gradient tracking
    output = model(ex)  # Note: model automatically handles normalisation
print(output.shape)  # Water/No-water segmentation: B, 2, 224, 224
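
When replacing the random tensors with real data, shape mismatches are an easy mistake to make. Below is a minimal sketch of a shape check, assuming the expected shapes are exactly those of the dummy inputs above (including the 32×32 context and 224×224 local sizes, which may only be example values). `check_inputs` is a hypothetical helper, not part of gff.

```python
def check_inputs(ex, C, context_hw=(32, 32), local_hw=(224, 224)):
    """Return the keys of `ex` that were checked; raise ValueError on a mismatch.

    Hypothetical helper (not part of gff). Expected shapes mirror the dummy
    inputs; the default spatial sizes are assumptions, not confirmed requirements.
    """
    B = next(iter(ex.values())).shape[0]  # batch size taken from the first input
    T = C["weather_window"]
    cH, cW = context_hw
    fH, fW = local_hw
    expected = {
        "era5": (B, T, len(C["era5_keys"]), cH, cW),
        "era5_land": (B, T, len(C["era5_land_keys"]), cH, cW),
        "glofas": (B, T, len(C["glofas_keys"]), cH, cW),
        "hydroatlas_basin": (B, len(C["hydroatlas_keys"]), cH, cW),
        "dem_context": (B, 1, cH, cW),
        "s1": (B, 2, fH, fW),
        "s1_lead_days": (B,),
        "dem_local": (B, 1, fH, fW),
        "hand": (B, 1, fH, fW),
    }
    checked = []
    for key, shape in expected.items():
        if key not in ex:
            continue  # only check inputs that were provided
        actual = tuple(ex[key].shape)
        if actual != shape:
            raise ValueError(f"{key}: expected shape {shape}, got {actual}")
        checked.append(key)
    return checked
```

It can be called as check_inputs(ex, C) just before the forward pass.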

If you want to use these weights outside this repo, then you have two main options.

  1. Copy the gff/models folder in its entirety into your project. It depends only on PyTorch and NumPy.
  2. Clone the repo and run pip install -e . from the project root. You can then import gff in your own project, which also gives you access to all of our dataset utility functions. However, this option requires every dependency listed in environment.yml.