From 8a479910c6da7fa6baa9f73e4811c343f79f0e53 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 28 Jan 2025 11:09:03 +0100 Subject: [PATCH] Reproject method for ImageData (#782) * Add reproject method to ImageData class with tests * Add reproject method for ImageData objects in CHANGES.md * Remove unused rasterio import and clean up test cases for reproject method * lint * fix and update tests * relaxe rasterio version limit * update changelog --------- Co-authored-by: vincentsarago --- CHANGES.md | 11 +++++++ pyproject.toml | 2 +- rio_tiler/models.py | 48 +++++++++++++++++++++++++++++-- tests/test_models.py | 68 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 125 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 4727d893..9d3c34c7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,16 @@ # Unreleased (TBD) +* update rasterio dependency to `>=1.4.0` + +* add `reproject` method for `ImageData` objects (author @emmanuelmathot, https://github.com/cogeotiff/rio-tiler/pull/782) + + ```python + from rio_tiler.models import ImageData + + img = ImageData(numpy.zeros((3, 256, 256), crs=CRS.from_epsg(4326), dtype="uint8")) + img_3857 = img.reproject("epsg:3857") + ``` + * add `indexes` parameter for `XarrayReader` methods. As for Rasterio, the indexes values start at `1`. ```python diff --git a/pyproject.toml b/pyproject.toml index 4375c486..54902a1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "morecantile>=5.0,<7.0", "pydantic~=2.0", "pystac>=0.5.4", - "rasterio>=1.3.0", + "rasterio>=1.4.0", "color-operations", "typing-extensions", "importlib_resources>=1.1.0; python_version < '3.9'", diff --git a/rio_tiler/models.py b/rio_tiler/models.py index 54dfd5ba..e6c955c4 100644 --- a/rio_tiler/models.py +++ b/rio_tiler/models.py @@ -14,13 +14,13 @@ from rasterio.coords import BoundingBox from rasterio.crs import CRS from rasterio.dtypes import dtype_ranges -from rasterio.enums import ColorInterp +from rasterio.enums import ColorInterp, Resampling from rasterio.errors import NotGeoreferencedWarning from rasterio.features import rasterize from rasterio.io import MemoryFile from rasterio.plot import reshape_as_image -from rasterio.transform import from_bounds -from rasterio.warp import transform_geom +from rasterio.transform import array_bounds, from_bounds +from rasterio.warp import calculate_default_transform, reproject, transform_geom from typing_extensions import Self from rio_tiler.colormap import apply_cmap @@ -34,6 +34,7 @@ IntervalTuple, NumType, RIOResampling, + WarpResampling, ) from rio_tiler.utils import ( _validate_shape_input, @@ -786,3 +787,44 @@ def get_coverage_array( ).astype("float32") return cover_array.sum(-1).sum(1) / (cover_scale**2) + + def reproject( + self, + dst_crs: CRS, + resolution: Optional[Tuple[float, float]] = None, + reproject_method: WarpResampling = "nearest", + ) -> "ImageData": + """Reproject data and mask.""" + dst_transform, w, h = calculate_default_transform( + self.crs, + dst_crs, + self.width, + self.height, + *self.bounds, + resolution=resolution, + ) + + destination = numpy.ma.masked_array( + numpy.zeros((self.count, h, w), dtype=self.array.dtype), + ) + destination, _ = reproject( + self.array, + destination, + src_transform=self.transform, + src_crs=self.crs, + dst_transform=dst_transform, + dst_crs=dst_crs, + resampling=Resampling[reproject_method], + ) + + bounds = array_bounds(h, w, dst_transform) + + return ImageData( + destination, + assets=self.assets, + crs=dst_crs, + bounds=bounds, + band_names=self.band_names, + metadata=self.metadata, + dataset_statistics=self.dataset_statistics, + ) diff --git a/tests/test_models.py b/tests/test_models.py index 988b300a..fc17452c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -466,3 +466,71 @@ def test_image_encoding_error(): """Test ImageData error when using bad data array shape.""" with pytest.raises(InvalidFormat): ImageData(numpy.zeros((5, 256, 256), dtype="uint8")).render(img_format="PNG") + + +def test_image_reproject(): + """Test basic reproject functionality.""" + data = numpy.zeros((1, 256, 256), dtype="uint8") + data[0:256, 0:256] = 1 + mask = numpy.zeros((1, 256, 256), dtype="bool") + mask[0:100, 0:100] = True + + # Create test image with WGS84 CRS + src_crs = CRS.from_epsg(4326) + img = ImageData( + numpy.ma.MaskedArray(data=data, mask=mask), + crs=src_crs, + bounds=(-95, 43, -92, 45), + metadata={"test": "value"}, + band_names=["band1"], + ) + + # Test re-projection to Web Mercator + dst_crs = CRS.from_epsg(3857) + + reprojected = img.reproject(dst_crs) + assert reprojected.crs == dst_crs + assert reprojected.count == 1 + assert reprojected.width != 256 + assert reprojected.height != 256 + assert reprojected.array[0, 0, 0].data == 0 + assert reprojected.array.data[0, -10, -10] == 1 + assert reprojected.array.mask.shape[0] == 1 + assert reprojected.array.mask[0, 0, 0] + assert not reprojected.array.mask[0, -10, -10] + assert reprojected.metadata == img.metadata + assert reprojected.band_names == img.band_names + + # Test no re-projection when CRS is the same + same_crs = img.reproject(src_crs) + assert same_crs.crs == src_crs + assert same_crs.transform == img.transform + numpy.testing.assert_array_equal(same_crs.array, img.array) + + # Test with different resampling method + reprojected_bilinear = img.reproject(dst_crs, reproject_method="bilinear") + with numpy.testing.assert_raises(AssertionError): + numpy.testing.assert_array_equal(reprojected_bilinear.array, img.array) + + # With MultiBands + data = numpy.zeros((3, 256, 256), dtype="uint8") + data[:, 0:256, 0:256] = 1 + mask = numpy.zeros((3, 256, 256), dtype="bool") + mask[:, 0:100, 0:100] = True + + img = ImageData( + numpy.ma.MaskedArray(data=data, mask=mask), + crs=src_crs, + bounds=(-95, 43, -92, 45), + ) + + reprojected = img.reproject(dst_crs) + assert reprojected.crs == dst_crs + assert reprojected.count == 3 + assert reprojected.width != 256 + assert reprojected.height != 256 + assert reprojected.array.data[:, 0, 0].tolist() == [0, 0, 0] + assert reprojected.array.data[:, -10, -10].tolist() == [1, 1, 1] + assert reprojected.array.mask.shape[0] == 3 + assert reprojected.array.mask[:, 0, 0].tolist() == [True, True, True] + assert reprojected.array.mask[:, -10, -10].tolist() == [False, False, False]