Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 6 additions & 16 deletions rslearn/config/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from upath import UPath

from rslearn.log_utils import get_logger
from rslearn.utils import PixelBounds, Projection
from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
from rslearn.utils.raster_format import RasterFormat
from rslearn.utils.vector_format import VectorFormat

Expand Down Expand Up @@ -210,22 +210,12 @@ def get_final_projection_and_bounds(
Returns:
tuple of updated projection and bounds with zoom offset applied
"""
if self.zoom_offset == 0:
return projection, bounds
projection = Projection(
projection.crs,
projection.x_resolution / (2**self.zoom_offset),
projection.y_resolution / (2**self.zoom_offset),
)
if self.zoom_offset > 0:
zoom_factor = 2**self.zoom_offset
bounds = tuple(x * zoom_factor for x in bounds) # type: ignore
if self.zoom_offset >= 0:
factor = ResolutionFactor(numerator=2**self.zoom_offset)
else:
bounds = tuple(
x // (2 ** (-self.zoom_offset))
for x in bounds # type: ignore
)
return projection, bounds
factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))

return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))

@field_validator("format", mode="before")
@classmethod
Expand Down
147 changes: 61 additions & 86 deletions rslearn/train/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tempfile
import time
import uuid
from dataclasses import dataclass
from typing import Any

import torch
Expand All @@ -24,7 +25,7 @@
from rslearn.log_utils import get_logger
from rslearn.train.tasks import Task
from rslearn.utils.feature import Feature
from rslearn.utils.geometry import PixelBounds
from rslearn.utils.geometry import PixelBounds, ResolutionFactor
from rslearn.utils.mp import star_imap_unordered

from .transforms import Sequential
Expand Down Expand Up @@ -124,56 +125,50 @@ def get_sampler(self, dataset: "ModelDataset") -> torch.utils.data.Sampler:
)


@dataclass
class DataInput:
"""Specification of a piece of data from a window that is needed for training.

The DataInput includes which layer(s) the data can be obtained from for each window.

Args:
data_type: either "raster" or "vector"
layers: list of layer names that this input can be read from.
bands: the bands to read, if this is a raster.
required: whether examples lacking one of these layers should be skipped
passthrough: whether to expose this to the model even if it isn't returned
by any task
is_target: whether this DataInput represents a target for the task. Targets
are not read during prediction phase.
dtype: data type to load the raster as
load_all_layers: whether to load all of the layers specified in the list of
layer names. By default, we randomly pick one layer to read. When
reading multiple layers, the images are stacked on the channel
dimension. This option will also cause the dataset to only include
windows where all of the layers are materialized (by default, only
windows with none of the layers materialized would be excluded).
load_all_item_groups: whether to load all item groups in the layer(s) we
are reading from. By default, we assume the specified layer name is of
the form "{layer_name}.{group_idx}" and read that item group only. With
this option enabled, we ignore the group_idx and read all item groups.
resolution_factor: raster inputs are read by default at the window
resolution. This is a multiplier to read at a different resolution,
e.g. if the window resolution is 64x64 at 10 m/pixel, and
resolution_factor=2, then the input is read as 32x32 at 20 m/pixel.
resampling: resampling method (default nearest neighbor).
"""

def __init__(
self,
data_type: str,
layers: list[str],
bands: list[str] | None = None,
required: bool = True,
passthrough: bool = False,
is_target: bool = False,
dtype: DType = DType.FLOAT32,
load_all_layers: bool = False,
load_all_item_groups: bool = False,
) -> None:
"""Initialize a new DataInput.

Args:
data_type: either "raster" or "vector"
layers: list of layer names that this input can be read from.
bands: the bands to read, if this is a raster.
required: whether examples lacking one of these layers should be skipped
passthrough: whether to expose this to the model even if it isn't returned
by any task
is_target: whether this DataInput represents a target for the task. Targets
are not read during prediction phase.
dtype: data type to load the raster as
load_all_layers: whether to load all of the layers specified in the list of
layer names. By default, we randomly pick one layer to read. When
reading multiple layers, the images are stacked on the channel
dimension. This option will also cause the dataset to only include
windows where all of the layers are materialized (by default, only
windows with none of the layers materialized would be excluded).
load_all_item_groups: whether to load all item groups in the layer(s) we
are reading from. By default, we assume the specified layer name is of
the form "{layer_name}.{group_idx}" and read that item group only. With
this option enabled, we ignore the group_idx and read all item groups.
"""
self.data_type = data_type
self.layers = layers
self.bands = bands
self.required = required
self.passthrough = passthrough
self.is_target = is_target
self.dtype = dtype
self.load_all_layers = load_all_layers
self.load_all_item_groups = load_all_item_groups
data_type: str
layers: list[str]
bands: list[str] | None = None
required: bool = True
passthrough: bool = False
is_target: bool = False
dtype: DType = DType.FLOAT32
load_all_layers: bool = False
load_all_item_groups: bool = False
resolution_factor: ResolutionFactor = ResolutionFactor()
resampling: Resampling = Resampling.nearest


def read_raster_layer_for_data_input(
Expand Down Expand Up @@ -231,60 +226,40 @@ def read_raster_layer_for_data_input(
+ f"window {window.name} layer {layer_name} group {group_idx}"
)

# Get the projection and bounds to read under (multiply the window resolution by
# the specified resolution factor).
final_projection = data_input.resolution_factor.multiply_projection(
window.projection
)
final_bounds = data_input.resolution_factor.multiply_bounds(bounds)

image = torch.zeros(
(len(needed_bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
(
len(needed_bands),
final_bounds[3] - final_bounds[1],
final_bounds[2] - final_bounds[0],
),
dtype=get_torch_dtype(data_input.dtype),
)

for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
final_projection, final_bounds = band_set.get_final_projection_and_bounds(
window.projection, bounds
)
if band_set.format is None:
raise ValueError(f"No format specified for {layer_name}")
raster_format = band_set.instantiate_raster_format()
raster_dir = window.get_raster_dir(
layer_name, band_set.bands, group_idx=group_idx
)

# Previously we always read in the native projection of the data, and then
# zoom in or out (the resolution must be a power of two off) to match the
# window's resolution.
# However, this fails if the bounds are not multiples of the resolution factor.
# So we fallback to reading directly in the window projection if that is the
# case (which may be a bit slower).
is_bounds_zoomable = True
if band_set.zoom_offset < 0:
zoom_factor = 2 ** (-band_set.zoom_offset)
is_bounds_zoomable = (final_bounds[2] - final_bounds[0]) * zoom_factor == (
bounds[2] - bounds[0]
) and (final_bounds[3] - final_bounds[1]) * zoom_factor == (
bounds[3] - bounds[1]
)

if is_bounds_zoomable:
src = raster_format.decode_raster(
raster_dir, final_projection, final_bounds
)

# Resize to patch size if needed.
# This is for band sets that are stored at a lower resolution.
# Here we assume that it is a multiple.
if src.shape[1:3] != image.shape[1:3]:
if src.shape[1] < image.shape[1]:
factor = image.shape[1] // src.shape[1]
src = src.repeat(repeats=factor, axis=1).repeat(
repeats=factor, axis=2
)
else:
factor = src.shape[1] // image.shape[1]
src = src[:, ::factor, ::factor]

else:
src = raster_format.decode_raster(
raster_dir, window.projection, bounds, resampling=Resampling.nearest
)
# TODO: previously we tried to read based on band_set.zoom_offset when possible,
# and handled zooming in with torch.repeat (if the resampling method is nearest
# neighbor). However, we have not benchmarked whether this actually improves
# data loading speed, so for simplicity, for now we let rasterio handle the
# resampling. If it really is much faster to handle it via torch, then it may
# make sense to bring back that functionality.

src = raster_format.decode_raster(
raster_dir, final_projection, final_bounds, resampling=Resampling.nearest
)
image[dst_indexes, :, :] = torch.as_tensor(
src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
)
Expand Down
71 changes: 71 additions & 0 deletions rslearn/utils/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,77 @@ def deserialize(d: dict) -> "Projection":
WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)


class ResolutionFactor:
    """Multiplier for the resolution in a Projection.

    The multiplier is either a positive integer x, or the inverse of one (1/x).

    Factors greater than 1 increase the projection_units/pixel resolution, coarsening
    the result (fewer pixels per projection unit). Factors less than 1 make it finer
    (more pixels).
    """

    def __init__(self, numerator: int = 1, denominator: int = 1):
        """Create a new ResolutionFactor.

        Args:
            numerator: the numerator of the fraction.
            denominator: the denominator of the fraction. If set, numerator must be 1.

        Raises:
            ValueError: if either value is not an integer, if either value is less
                than 1, or if both values differ from 1.
        """
        # Check the types first so that a wrong type is reported as such rather than
        # as a "must be 1" error from the value checks below.
        if not isinstance(numerator, int) or not isinstance(denominator, int):
            raise ValueError("numerator and denominator must be integers")
        if numerator != 1 and denominator != 1:
            raise ValueError("one of numerator or denominator must be 1")
        if numerator < 1 or denominator < 1:
            raise ValueError("numerator and denominator must be >= 1")
        self.numerator = numerator
        self.denominator = denominator

    def __repr__(self) -> str:
        """Return a constructor-style representation for debugging."""
        return (
            f"{type(self).__name__}(numerator={self.numerator}, "
            f"denominator={self.denominator})"
        )

    def multiply_projection(self, projection: Projection) -> Projection:
        """Multiply the resolution of the projection by this factor.

        Args:
            projection: the projection whose x/y resolution should be scaled.

        Returns:
            a new Projection with the same CRS and the scaled x/y resolution.
        """
        if self.denominator > 1:
            # Use true division here: the resolution is not guaranteed to be a
            # multiple of the denominator, and floor division would silently round
            # it (e.g. 10 // 4 == 2 instead of 2.5, and -10 // 4 == -3).
            return Projection(
                projection.crs,
                projection.x_resolution / self.denominator,
                projection.y_resolution / self.denominator,
            )

        return Projection(
            projection.crs,
            projection.x_resolution * self.numerator,
            projection.y_resolution * self.numerator,
        )

    def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
        """Multiply the bounds by this factor.

        When coarsening, the width and height of the given bounds must be a multiple of
        the numerator.

        Args:
            bounds: the pixel bounds, indexed as (bounds[0], bounds[1], bounds[2],
                bounds[3]) with width = bounds[2] - bounds[0] and
                height = bounds[3] - bounds[1].

        Returns:
            the bounds expressed in pixels of the projection obtained from
                multiply_projection.

        Raises:
            ValueError: if coarsening and the width or height is not a multiple of
                the numerator.
        """
        if self.denominator > 1:
            # Finer resolution: each original pixel becomes `denominator` pixels.
            return (
                bounds[0] * self.denominator,
                bounds[1] * self.denominator,
                bounds[2] * self.denominator,
                bounds[3] * self.denominator,
            )

        # Coarser (or unchanged) resolution. Verify the width and height are
        # multiples of the numerator, otherwise the new width and height would not be
        # integers.
        width = bounds[2] - bounds[0]
        height = bounds[3] - bounds[1]
        if width % self.numerator != 0 or height % self.numerator != 0:
            raise ValueError(
                f"width {width} or height {height} is not a multiple of the resolution factor {self.numerator}"
            )
        return (
            bounds[0] // self.numerator,
            bounds[1] // self.numerator,
            bounds[2] // self.numerator,
            bounds[3] // self.numerator,
        )


class STGeometry:
"""A spatiotemporal geometry.

Expand Down
42 changes: 42 additions & 0 deletions rslearn/utils/jsonargparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from upath import UPath

from rslearn.config.dataset import LayerConfig
from rslearn.utils.geometry import ResolutionFactor

if TYPE_CHECKING:
from rslearn.data_sources.data_source import DataSourceContext
Expand Down Expand Up @@ -91,6 +92,44 @@ def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
)


def resolution_factor_serializer(v: ResolutionFactor) -> str:
    """Encode a ResolutionFactor as a string for jsonargparse.

    Args:
        v: the ResolutionFactor to encode.

    Returns:
        the factor written as "numerator/denominator", e.g. "1/4" or "2/1".
    """
    fraction_terms = (v.numerator, v.denominator)
    return "/".join(str(term) for term in fraction_terms)


def resolution_factor_deserializer(v: int | str) -> ResolutionFactor:
    """Decode a ResolutionFactor from its jsonargparse representation.

    Args:
        v: the encoded factor. Either an integer x, or a string of the form "x" or
            "numerator/denominator".

    Returns:
        the decoded ResolutionFactor object

    Raises:
        ValueError: if v is neither an int nor a str, or if the string contains more
            than one "/".
    """
    # A plain integer is used directly as the numerator.
    if isinstance(v, int):
        return ResolutionFactor(numerator=v)

    if not isinstance(v, str):
        raise ValueError("expected resolution factor to be str or int")

    pieces = v.split("/")
    if len(pieces) > 2:
        raise ValueError("expected resolution factor to be of the form x or 1/x")

    # No slash at all: the string is just the numerator.
    if len(pieces) == 1:
        return ResolutionFactor(numerator=int(pieces[0]))

    top, bottom = pieces
    return ResolutionFactor(numerator=int(top), denominator=int(bottom))


def init_jsonargparse() -> None:
"""Initialize custom jsonargparse serializers."""
global INITIALIZED
Expand All @@ -100,6 +139,9 @@ def init_jsonargparse() -> None:
jsonargparse.typing.register_type(
datetime, datetime_serializer, datetime_deserializer
)
jsonargparse.typing.register_type(
ResolutionFactor, resolution_factor_serializer, resolution_factor_deserializer
)

from rslearn.data_sources.data_source import DataSourceContext

Expand Down
Loading
Loading