diff --git a/rslearn/config/dataset.py b/rslearn/config/dataset.py
index c4312299..6e478200 100644
--- a/rslearn/config/dataset.py
+++ b/rslearn/config/dataset.py
@@ -25,7 +25,7 @@
 from upath import UPath
 
 from rslearn.log_utils import get_logger
-from rslearn.utils import PixelBounds, Projection
+from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
 from rslearn.utils.raster_format import RasterFormat
 from rslearn.utils.vector_format import VectorFormat
 
@@ -210,22 +210,12 @@ def get_final_projection_and_bounds(
         Returns:
             tuple of updated projection and bounds with zoom offset applied
         """
-        if self.zoom_offset == 0:
-            return projection, bounds
-        projection = Projection(
-            projection.crs,
-            projection.x_resolution / (2**self.zoom_offset),
-            projection.y_resolution / (2**self.zoom_offset),
-        )
-        if self.zoom_offset > 0:
-            zoom_factor = 2**self.zoom_offset
-            bounds = tuple(x * zoom_factor for x in bounds)  # type: ignore
+        if self.zoom_offset >= 0:
+            factor = ResolutionFactor(denominator=2**self.zoom_offset)
         else:
-            bounds = tuple(
-                x // (2 ** (-self.zoom_offset))
-                for x in bounds  # type: ignore
-            )
-        return projection, bounds
+            factor = ResolutionFactor(numerator=2 ** (-self.zoom_offset))
+
+        return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))
 
     @field_validator("format", mode="before")
     @classmethod
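A minimal usage sketch of the new helper (the CRS, resolution, and window size below are arbitrary examples, not taken from this PR). A positive zoom offset means the band set is stored at a finer resolution than the window, so it corresponds to a denominator factor, which divides the resolution and scales the pixel bounds up:

from rasterio.crs import CRS

from rslearn.utils.geometry import Projection, ResolutionFactor

# Window at 10 m/pixel covering a 64x64 pixel area.
projection = Projection(CRS.from_epsg(32610), 10, -10)
bounds = (0, 0, 64, 64)

# zoom_offset = 1 -> the data is stored at twice the resolution.
factor = ResolutionFactor(denominator=2)
new_projection = factor.multiply_projection(projection)
print(new_projection.x_resolution, new_projection.y_resolution)  # 5.0 -5.0
print(factor.multiply_bounds(bounds))  # (0, 0, 128, 128)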
diff --git a/rslearn/train/dataset.py b/rslearn/train/dataset.py
index ca880213..747579ec 100644
--- a/rslearn/train/dataset.py
+++ b/rslearn/train/dataset.py
@@ -8,6 +8,7 @@
 import tempfile
 import time
 import uuid
+from dataclasses import dataclass
 from typing import Any
 
 import torch
@@ -24,7 +25,7 @@
 from rslearn.log_utils import get_logger
 from rslearn.train.tasks import Task
 from rslearn.utils.feature import Feature
-from rslearn.utils.geometry import PixelBounds
+from rslearn.utils.geometry import PixelBounds, ResolutionFactor
 from rslearn.utils.mp import star_imap_unordered
 
 from .transforms import Sequential
@@ -124,56 +125,50 @@ def get_sampler(self, dataset: "ModelDataset") -> torch.utils.data.Sampler:
         )
 
 
+@dataclass
 class DataInput:
     """Specification of a piece of data from a window that is needed for training.
 
     The DataInput includes which layer(s) the data can be obtained from for each
     window.
+
+    Args:
+        data_type: either "raster" or "vector"
+        layers: list of layer names that this input can be read from.
+        bands: the bands to read, if this is a raster.
+        required: whether examples lacking one of these layers should be skipped
+        passthrough: whether to expose this to the model even if it isn't returned
+            by any task
+        is_target: whether this DataInput represents a target for the task. Targets
+            are not read during prediction phase.
+        dtype: data type to load the raster as
+        load_all_layers: whether to load all of the layers specified in the list of
+            layer names. By default, we randomly pick one layer to read. When
+            reading multiple layers, the images are stacked on the channel
+            dimension. This option will also cause the dataset to only include
+            windows where all of the layers are materialized (by default, only
+            windows with none of the layers materialized would be excluded).
+        load_all_item_groups: whether to load all item groups in the layer(s) we
+            are reading from. By default, we assume the specified layer name is of
+            the form "{layer_name}.{group_idx}" and read that item group only. With
+            this option enabled, we ignore the group_idx and read all item groups.
+        resolution_factor: raster inputs are read by default at the window
+            resolution. This is a multiplier to read at a different resolution,
+            e.g. if the window resolution is 64x64 at 10 m/pixel, and
+            resolution_factor=2, then the input is read as 32x32 at 20 m/pixel.
+        resampling: resampling method (default nearest neighbor).
     """
 
-    def __init__(
-        self,
-        data_type: str,
-        layers: list[str],
-        bands: list[str] | None = None,
-        required: bool = True,
-        passthrough: bool = False,
-        is_target: bool = False,
-        dtype: DType = DType.FLOAT32,
-        load_all_layers: bool = False,
-        load_all_item_groups: bool = False,
-    ) -> None:
-        """Initialize a new DataInput.
-
-        Args:
-            data_type: either "raster" or "vector"
-            layers: list of layer names that this input can be read from.
-            bands: the bands to read, if this is a raster.
-            required: whether examples lacking one of these layers should be skipped
-            passthrough: whether to expose this to the model even if it isn't returned
-                by any task
-            is_target: whether this DataInput represents a target for the task. Targets
-                are not read during prediction phase.
-            dtype: data type to load the raster as
-            load_all_layers: whether to load all of the layers specified in the list of
-                layer names. By default, we randomly pick one layer to read. When
-                reading multiple layers, the images are stacked on the channel
-                dimension. This option will also cause the dataset to only include
-                windows where all of the layers are materialized (by default, only
-                windows with none of the layers materialized would be excluded).
-            load_all_item_groups: whether to load all item groups in the layer(s) we
-                are reading from. By default, we assume the specified layer name is of
-                the form "{layer_name}.{group_idx}" and read that item group only. With
-                this option enabled, we ignore the group_idx and read all item groups.
-        """
-        self.data_type = data_type
-        self.layers = layers
-        self.bands = bands
-        self.required = required
-        self.passthrough = passthrough
-        self.is_target = is_target
-        self.dtype = dtype
-        self.load_all_layers = load_all_layers
-        self.load_all_item_groups = load_all_item_groups
+    data_type: str
+    layers: list[str]
+    bands: list[str] | None = None
+    required: bool = True
+    passthrough: bool = False
+    is_target: bool = False
+    dtype: DType = DType.FLOAT32
+    load_all_layers: bool = False
+    load_all_item_groups: bool = False
+    resolution_factor: ResolutionFactor = ResolutionFactor()
+    resampling: Resampling = Resampling.nearest
 
 
 def read_raster_layer_for_data_input(
@@ -231,15 +226,23 @@
             + f"window {window.name} layer {layer_name} group {group_idx}"
         )
 
+    # Get the projection and bounds to read under (multiply the window resolution by
+    # the specified resolution factor).
+    final_projection = data_input.resolution_factor.multiply_projection(
+        window.projection
+    )
+    final_bounds = data_input.resolution_factor.multiply_bounds(bounds)
+
     image = torch.zeros(
-        (len(needed_bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
+        (
+            len(needed_bands),
+            final_bounds[3] - final_bounds[1],
+            final_bounds[2] - final_bounds[0],
+        ),
         dtype=get_torch_dtype(data_input.dtype),
    )
     for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
-        final_projection, final_bounds = band_set.get_final_projection_and_bounds(
-            window.projection, bounds
-        )
         if band_set.format is None:
             raise ValueError(f"No format specified for {layer_name}")
         raster_format = band_set.instantiate_raster_format()
@@ -247,44 +250,16 @@
             layer_name, band_set.bands, group_idx=group_idx
         )
 
-        # Previously we always read in the native projection of the data, and then
-        # zoom in or out (the resolution must be a power of two off) to match the
-        # window's resolution.
-        # However, this fails if the bounds are not multiples of the resolution factor.
-        # So we fallback to reading directly in the window projection if that is the
-        # case (which may be a bit slower).
-        is_bounds_zoomable = True
-        if band_set.zoom_offset < 0:
-            zoom_factor = 2 ** (-band_set.zoom_offset)
-            is_bounds_zoomable = (final_bounds[2] - final_bounds[0]) * zoom_factor == (
-                bounds[2] - bounds[0]
-            ) and (final_bounds[3] - final_bounds[1]) * zoom_factor == (
-                bounds[3] - bounds[1]
-            )
-
-        if is_bounds_zoomable:
-            src = raster_format.decode_raster(
-                raster_dir, final_projection, final_bounds
-            )
-
-            # Resize to patch size if needed.
-            # This is for band sets that are stored at a lower resolution.
-            # Here we assume that it is a multiple.
-            if src.shape[1:3] != image.shape[1:3]:
-                if src.shape[1] < image.shape[1]:
-                    factor = image.shape[1] // src.shape[1]
-                    src = src.repeat(repeats=factor, axis=1).repeat(
-                        repeats=factor, axis=2
-                    )
-                else:
-                    factor = src.shape[1] // image.shape[1]
-                    src = src[:, ::factor, ::factor]
-
-        else:
-            src = raster_format.decode_raster(
-                raster_dir, window.projection, bounds, resampling=Resampling.nearest
-            )
+        # TODO: previously we tried to read based on band_set.zoom_offset when possible,
+        # and handled zooming in with torch.repeat (if the resampling method is nearest
+        # neighbor). However, we have not benchmarked whether this actually improves
+        # data loading speed, so for simplicity, for now we let rasterio handle the
+        # resampling. If it really is much faster to handle it via torch, then it may
+        # make sense to bring back that functionality.
+        src = raster_format.decode_raster(
+            raster_dir, final_projection, final_bounds, resampling=data_input.resampling
+        )
 
         image[dst_indexes, :, :] = torch.as_tensor(
             src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
        )
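For illustration, a short sketch of declaring an input with the new dataclass fields (the layer and band names here are placeholders, and the bilinear choice is just an example). A resolution_factor with numerator 2 reads the input at half the window's width and height, and the resampling method is passed along when the raster is decoded:

from rasterio.enums import Resampling

from rslearn.config import DType
from rslearn.train.dataset import DataInput
from rslearn.utils.geometry import ResolutionFactor

# Hypothetical target raster read at half the window resolution.
label_input = DataInput(
    data_type="raster",
    layers=["label"],
    bands=["B1"],
    dtype=DType.INT32,
    is_target=True,
    resolution_factor=ResolutionFactor(numerator=2),
    resampling=Resampling.bilinear,
)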
diff --git a/rslearn/utils/geometry.py b/rslearn/utils/geometry.py
index 04136e17..e0093d0e 100644
--- a/rslearn/utils/geometry.py
+++ b/rslearn/utils/geometry.py
@@ -116,6 +116,77 @@ def deserialize(d: dict) -> "Projection":
 WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
 
 
+class ResolutionFactor:
+    """Multiplier for the resolution in a Projection.
+
+    The multiplier is either an integer x, or the inverse of an integer (1/x).
+
+    Factors greater than 1 increase the projection_units/pixel resolution, coarsening
+    the result (fewer pixels per projection unit). Factors less than 1 make it finer
+    (more pixels).
+    """
+
+    def __init__(self, numerator: int = 1, denominator: int = 1):
+        """Create a new ResolutionFactor.
+
+        Args:
+            numerator: the numerator of the fraction.
+            denominator: the denominator of the fraction. If set, numerator must be 1.
+        """
+        if numerator != 1 and denominator != 1:
+            raise ValueError("one of numerator or denominator must be 1")
+        if not isinstance(numerator, int) or not isinstance(denominator, int):
+            raise ValueError("numerator and denominator must be integers")
+        if numerator < 1 or denominator < 1:
+            raise ValueError("numerator and denominator must be >= 1")
+        self.numerator = numerator
+        self.denominator = denominator
+
+    def multiply_projection(self, projection: Projection) -> Projection:
+        """Multiply the projection by this factor."""
+        if self.denominator > 1:
+            return Projection(
+                projection.crs,
+                projection.x_resolution / self.denominator,
+                projection.y_resolution / self.denominator,
+            )
+        else:
+            return Projection(
+                projection.crs,
+                projection.x_resolution * self.numerator,
+                projection.y_resolution * self.numerator,
+            )
+
+    def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
+        """Multiply the bounds by this factor.
+
+        When coarsening, the width and height of the given bounds must be a multiple of
+        the numerator.
+        """
+        if self.denominator > 1:
+            return (
+                bounds[0] * self.denominator,
+                bounds[1] * self.denominator,
+                bounds[2] * self.denominator,
+                bounds[3] * self.denominator,
+            )
+        else:
+            # Verify the width and height are multiples of the numerator.
+            # Otherwise the new width and height would not be integers.
+            width = bounds[2] - bounds[0]
+            height = bounds[3] - bounds[1]
+            if width % self.numerator != 0 or height % self.numerator != 0:
+                raise ValueError(
+                    f"width {width} or height {height} is not a multiple of the "
+                    f"resolution factor {self.numerator}"
+                )
+            return (
+                bounds[0] // self.numerator,
+                bounds[1] // self.numerator,
+                bounds[2] // self.numerator,
+                bounds[3] // self.numerator,
+            )
+
+
 class STGeometry:
     """A spatiotemporal geometry.
 
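For illustration (values arbitrary): coarsening with an integer factor requires the width and height of the bounds to be divisible by that factor; otherwise multiply_bounds raises ValueError rather than producing fractional bounds.

from rslearn.utils.geometry import ResolutionFactor

factor = ResolutionFactor(numerator=4)
print(factor.multiply_bounds((0, 0, 64, 64)))  # (0, 0, 16, 16)

try:
    factor.multiply_bounds((0, 0, 63, 63))
except ValueError as e:
    print(e)  # width 63 or height 63 is not a multiple of the resolution factor 4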
+ """ + + def __init__(self, numerator: int = 1, denominator: int = 1): + """Create a new ResolutionFactor. + + Args: + numerator: the numerator of the fraction. + denominator: the denominator of the fraction. If set, numerator must be 1. + """ + if numerator != 1 and denominator != 1: + raise ValueError("one of numerator or denominator must be 1") + if not isinstance(numerator, int) or not isinstance(denominator, int): + raise ValueError("numerator and denominator must be integers") + if numerator < 1 or denominator < 1: + raise ValueError("numerator and denominator must be >= 1") + self.numerator = numerator + self.denominator = denominator + + def multiply_projection(self, projection: Projection) -> Projection: + """Multiply the projection by this factor.""" + if self.denominator > 1: + return Projection( + projection.crs, + projection.x_resolution // self.denominator, + projection.y_resolution // self.denominator, + ) + else: + return Projection( + projection.crs, + projection.x_resolution * self.numerator, + projection.y_resolution * self.numerator, + ) + + def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds: + """Multiply the bounds by this factor. + + When coarsening, the width and height of the given bounds must be a multiple of + the numerator. + """ + if self.denominator > 1: + return ( + bounds[0] * self.denominator, + bounds[1] * self.denominator, + bounds[2] * self.denominator, + bounds[3] * self.denominator, + ) + else: + # Verify the width and height are multiples of the numerator. + # Otherwise the new width and height is not an integer. + width = bounds[2] - bounds[0] + height = bounds[3] - bounds[1] + if width % self.numerator != 0 or height % self.numerator != 0: + raise ValueError( + f"width {width} or height {height} is not a multiple of the resolution factor {self.numerator}" + ) + return ( + bounds[0] // self.numerator, + bounds[1] // self.numerator, + bounds[2] // self.numerator, + bounds[3] // self.numerator, + ) + + class STGeometry: """A spatiotemporal geometry. diff --git a/rslearn/utils/jsonargparse.py b/rslearn/utils/jsonargparse.py index 2635168c..ff148973 100644 --- a/rslearn/utils/jsonargparse.py +++ b/rslearn/utils/jsonargparse.py @@ -8,6 +8,7 @@ from upath import UPath from rslearn.config.dataset import LayerConfig +from rslearn.utils.geometry import ResolutionFactor if TYPE_CHECKING: from rslearn.data_sources.data_source import DataSourceContext @@ -91,6 +92,44 @@ def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext": ) +def resolution_factor_serializer(v: ResolutionFactor) -> str: + """Serialize ResolutionFactor for jsonargparse. + + Args: + v: the ResolutionFactor object. + + Returns: + the ResolutionFactor encoded to string + """ + return f"{v.numerator}/{v.denominator}" + + +def resolution_factor_deserializer(v: int | str) -> ResolutionFactor: + """Deserialize ResolutionFactor for jsonargparse. + + Args: + v: the encoded ResolutionFactor. 
diff --git a/tests/integration/train/test_dataset.py b/tests/integration/train/test_dataset.py
index ef8b7675..8557a235 100644
--- a/tests/integration/train/test_dataset.py
+++ b/tests/integration/train/test_dataset.py
@@ -2,13 +2,30 @@
 import pathlib
 from typing import Any
 
+import lightning.pytorch as pl
+import numpy as np
+import pytest
+import torch
 from rasterio.crs import CRS
 from upath import UPath
 
+from rslearn.config import BandSetConfig, DatasetConfig, DType, LayerConfig, LayerType
 from rslearn.dataset import Dataset, Window
-from rslearn.train.dataset import ModelDataset, SplitConfig
+from rslearn.models.conv import Conv
+from rslearn.models.module_wrapper import EncoderModuleWrapper
+from rslearn.models.pick_features import PickFeatures
+from rslearn.models.singletask import SingleTaskModel
+from rslearn.train.data_module import RslearnDataModule
+from rslearn.train.dataset import DataInput, ModelDataset, SplitConfig
+from rslearn.train.lightning_module import RslearnLightningModule
+from rslearn.train.optimizer import AdamW
 from rslearn.train.tasks.classification import ClassificationTask
-from rslearn.utils.geometry import Projection
+from rslearn.train.tasks.per_pixel_regression import (
+    PerPixelRegressionHead,
+    PerPixelRegressionTask,
+)
+from rslearn.utils.geometry import Projection, ResolutionFactor
+from rslearn.utils.raster_format import GeotiffRasterFormat
 
 
 class TestDataset:
@@ -71,3 +88,158 @@ def test_empty_dataset(self, tmp_path: pathlib.Path) -> None:
             workers=4,
         )
         assert len(dataset) == 0
+
+
+class TestResolutionFactor:
+    """Integration test for ModelDataset with DataInputs that have a resolution factor.
+
+    We verify that we can train a model that takes a 4x4 input but produces a 2x2
+    output with PerPixelRegressionTask.
+    """
+
+    def create_dataset(self, ds_path: UPath) -> Dataset:
+        """Write the dataset config and return dataset."""
+        cfg = DatasetConfig(
+            layers=dict(
+                image=LayerConfig(
+                    type=LayerType.RASTER,
+                    band_sets=[
+                        BandSetConfig(
+                            dtype=DType.UINT8,
+                            bands=["B1"],
+                        )
+                    ],
+                ),
+                label=LayerConfig(
+                    type=LayerType.RASTER,
+                    band_sets=[
+                        BandSetConfig(
+                            dtype=DType.UINT8,
+                            bands=["B1"],
+                        )
+                    ],
+                ),
+            )
+        )
+        with (ds_path / "config.json").open("w") as f:
+            f.write(cfg.model_dump_json())
+        return Dataset(ds_path)
+
+    def add_window(self, ds_path: UPath, group: str, name: str) -> Window:
+        """Add a window with the specified name."""
+        window = Window(
+            path=Window.get_window_root(ds_path, group, name),
+            group=group,
+            name=name,
+            projection=Projection(CRS.from_epsg(3857), 1, -1),
+            bounds=(0, 0, 4, 4),
+            time_range=None,
+        )
+        window.save()
+
+        # Add image layer.
+        GeotiffRasterFormat().encode_raster(
+            window.get_raster_dir("image", ["B1"]),
+            window.projection,
+            window.bounds,
+            np.ones((1, 4, 4), dtype=np.uint8),
+        )
+        window.mark_layer_completed("image")
+
+        # Add label layer.
+        GeotiffRasterFormat().encode_raster(
+            window.get_raster_dir("label", ["B1"]),
+            window.projection,
+            window.bounds,
+            2 * np.ones((1, 4, 4), dtype=np.uint8),
+        )
+        window.mark_layer_completed("label")
+
+        return window
+
+ """ + + def create_dataset(self, ds_path: UPath) -> Dataset: + """Write the dataset config and return dataset.""" + cfg = DatasetConfig( + layers=dict( + image=LayerConfig( + type=LayerType.RASTER, + band_sets=[ + BandSetConfig( + dtype=DType.UINT8, + bands=["B1"], + ) + ], + ), + label=LayerConfig( + type=LayerType.RASTER, + band_sets=[ + BandSetConfig( + dtype=DType.UINT8, + bands=["B1"], + ) + ], + ), + ) + ) + with (ds_path / "config.json").open("w") as f: + f.write(cfg.model_dump_json()) + return Dataset(ds_path) + + def add_window(self, ds_path: UPath, group: str, name: str) -> Window: + """Add a window with the specified name.""" + window = Window( + path=Window.get_window_root(ds_path, group, name), + group=group, + name=name, + projection=Projection(CRS.from_epsg(3857), 1, -1), + bounds=(0, 0, 4, 4), + time_range=None, + ) + window.save() + + # Add image layer. + GeotiffRasterFormat().encode_raster( + window.get_raster_dir("image", ["B1"]), + window.projection, + window.bounds, + np.ones((1, 4, 4), dtype=np.uint8), + ) + window.mark_layer_completed("image") + + # Add label layer. + GeotiffRasterFormat().encode_raster( + window.get_raster_dir("label", ["B1"]), + window.projection, + window.bounds, + 2 * np.ones((1, 4, 4), dtype=np.uint8), + ) + window.mark_layer_completed("label") + + return window + + def test_per_pixel_regression(self, tmp_path: pathlib.Path) -> None: + """Run the test with PerPixelRegressionTask.""" + ds_path = UPath(tmp_path) + self.create_dataset(ds_path) + for idx in range(16): + self.add_window(ds_path, "default", f"window{idx}") + + task = PerPixelRegressionTask() + + # Create the data module. + data_module = RslearnDataModule( + task=task, + inputs=dict( + image=DataInput( + data_type="raster", + layers=["image"], + bands=["B1"], + dtype=DType.FLOAT32, + passthrough=True, + ), + targets=DataInput( + data_type="raster", + layers=["label"], + bands=["B1"], + dtype=DType.INT32, + is_target=True, + # Here we set the resolution factor so the target is 2x2. + resolution_factor=ResolutionFactor(numerator=2), + ), + ), + path=str(ds_path), + ) + + # Create the model architecture. It is just two convs, one to downsample and + # another to make the prediction. + model = SingleTaskModel( + encoder=[ + EncoderModuleWrapper( + module=Conv( + in_channels=1, + out_channels=32, + kernel_size=3, + stride=2, + padding=1, + ), + ), + ], + decoder=[ + Conv( + in_channels=32, + out_channels=1, + kernel_size=3, + activation=torch.nn.Identity(), + ), + PickFeatures(indexes=[0], collapse=True), + PerPixelRegressionHead(), + ], + ) + + # Perform fit. + lm = RslearnLightningModule( + model, + task=task, + optimizer=AdamW(lr=0.001), + ) + trainer = pl.Trainer(max_epochs=10) + trainer.fit(lm, datamodule=data_module) + + # Make sure model produces the right output. + model.eval() + output = ( + model( + [ + { + "image": torch.ones((1, 4, 4), dtype=torch.float32), + } + ] + )["outputs"] + .detach() + .numpy() + ) + print(output) + # Index into BCHW tensor. + assert output[0, 0, 0] == pytest.approx(2, abs=0.01) + assert output[0, 0, 1] == pytest.approx(2, abs=0.01) + assert output[0, 1, 0] == pytest.approx(2, abs=0.01) + assert output[0, 1, 1] == pytest.approx(2, abs=0.01)