Skip to content

Commit

Permalink
Add reproject method to ImageData class with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuelmathot committed Jan 25, 2025
1 parent 4c1b2a6 commit 68e9a41
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 2 deletions.
48 changes: 46 additions & 2 deletions rio_tiler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
from numpy.typing import NDArray
from pydantic import BaseModel
from rasterio import windows
import rasterio
from rasterio.coords import BoundingBox
from rasterio.crs import CRS
from rasterio.dtypes import dtype_ranges
from rasterio.enums import ColorInterp
from rasterio.enums import ColorInterp, Resampling
from rasterio.errors import NotGeoreferencedWarning
from rasterio.features import rasterize
from rasterio.io import MemoryFile
from rasterio.plot import reshape_as_image
from rasterio.transform import from_bounds
from rasterio.warp import transform_geom
from rasterio.warp import transform_geom, reproject, calculate_default_transform, transform_bounds
from typing_extensions import Self

from rio_tiler.colormap import apply_cmap
Expand Down Expand Up @@ -786,3 +787,46 @@ def get_coverage_array(
).astype("float32")

return cover_array.sum(-1).sum(1) / (cover_scale**2)

def reproject(
self,
dst_crs: CRS,
resolution: Optional[Tuple[float, float]] = None,
resampling_method: Resampling = Resampling.nearest,
) -> "ImageData":
"""Reproject data and mask."""
dst_transform, w, h = calculate_default_transform(
self.crs,
dst_crs,
self.width,
self.height,
*self.bounds,
resolution=resolution
)

destination = numpy.ma.MaskedArray(
numpy.zeros((self.count, h, w), dtype=self.array.dtype),
mask=numpy.zeros((self.count, h, w), dtype=bool),
)
reprojection, _ = reproject(
self.array,
destination=destination.data,
src_transform=self.transform,
src_crs=self.crs,
dst_transform=dst_transform,
dst_crs=dst_crs,
resampling=resampling_method,
masked=True,
)

new_bounds = transform_bounds(self.crs, dst_crs, *self.bounds)

return ImageData(
reprojection,
assets=self.assets,
crs=dst_crs,
bounds=new_bounds,
band_names=self.band_names,
metadata=self.metadata,
dataset_statistics=self.dataset_statistics,
)
115 changes: 115 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import warnings
from io import BytesIO

from affine import Affine
import numpy
import pytest
import rasterio
from rasterio.crs import CRS
from rasterio.errors import NotGeoreferencedWarning
from rasterio.io import MemoryFile
from rasterio.warp import Resampling

from rio_tiler.errors import (
InvalidDatatypeWarning,
Expand Down Expand Up @@ -466,3 +468,116 @@ def test_image_encoding_error():
"""Test ImageData error when using bad data array shape."""
with pytest.raises(InvalidFormat):
ImageData(numpy.zeros((5, 256, 256), dtype="uint8")).render(img_format="PNG")

def test_reproject_basic():
"""Test basic reproject functionality."""
data = numpy.zeros((1, 256, 256), dtype="uint8")
data[0:256, 0:256] = 1
mask = numpy.zeros((1, 256, 256), dtype="bool")
mask[0:256, 0:256] = False

# Create test image with WGS84 CRS
src_crs = CRS.from_epsg(4326)
img = ImageData(
numpy.ma.MaskedArray(data=data, mask=mask),
crs=src_crs,
bounds=(-180, -85, 180, 85),
)

# Test reprojection to Web Mercator
dst_crs = CRS.from_epsg(3857)

reprojected = img.reproject(dst_crs)

assert reprojected.crs == dst_crs
assert reprojected.array.shape == (1, 256, 256)
assert reprojected.array.mask.shape == (1, 256, 256)

# Test no reprojection when CRS is the same
same_crs = img.reproject(src_crs)
assert same_crs.crs == src_crs
assert same_crs.transform == img.transform
numpy.testing.assert_array_equal(same_crs.array.data, img.array.data)
numpy.testing.assert_array_equal(same_crs.array.mask, img.array.mask)

# Test with different resampling method
reprojected_bilinear = img.reproject(dst_crs, resampling_method=Resampling.bilinear)
assert reprojected_bilinear.crs == dst_crs
assert reprojected_bilinear.width == 256
assert reprojected_bilinear.height == 256
assert isinstance(reprojected_bilinear.transform, Affine)


def test_reproject_with_data():
"""Test reproject with actual data values."""
# Create a test pattern
data = numpy.zeros((1, 256, 256), dtype="uint8")
data[0, 40:170, 40:170] = 255 # Add a square of 255 values
mask = numpy.zeros((1, 256, 256), dtype=bool)
mask[0, 50:100, 50:100] = True # Add a masked point

src_crs = CRS.from_epsg(4326)
img = ImageData(
numpy.ma.MaskedArray(data=data, mask=mask),
crs=src_crs,
bounds=(-180, -85, 180, 85),
metadata={"test": "value"},
band_names=["band1"],
)
with open("img.png", "wb") as f:
f.write(img.render())

assert numpy.any(img.array.data > 0)

# Test reprojection to Web Mercator
dst_crs = CRS.from_epsg(3857)

reprojected = img.reproject(dst_crs)

with open("reprojected.png", "wb") as f:
f.write(reprojected.render())

# Check metadata preservation
assert reprojected.metadata == img.metadata
assert reprojected.band_names == img.band_names

# Check data and mask shapes
assert reprojected.array.data.shape == (1, 256, 256)
assert reprojected.array.mask.shape == (1, 256, 256)

# Verify some data is preserved (exact values may change due to resampling)
assert numpy.any(reprojected.array.data > 0)
# QUESTION: rasterio does not seemd to supprt masked arrays.
# Should we reproject the mask as well?
# assert numpy.any(reprojected.array.mask)


def test_reproject_multiband():
"""Test reproject with multiple bands."""
data = numpy.zeros((3, 10, 10))
data[0, 4:7, 4:7] = 1 # Red band
data[1, 3:8, 3:8] = 0.5 # Green band
data[2, 2:9, 2:9] = 0.25 # Blue band

mask = numpy.zeros((3, 10, 10), dtype=bool)
mask[0, 5, 5] = True

src_crs = CRS.from_epsg(4326)
img = ImageData(
numpy.ma.MaskedArray(data=data, mask=mask),
crs=src_crs,
bounds=(-180, -85, 180, 85),
band_names=["red", "green", "blue"],
)

dst_crs = CRS.from_epsg(3857)

reprojected = img.reproject(dst_crs)

assert reprojected.count == 3
assert reprojected.band_names == ["red", "green", "blue"]

# Check each band has unique values
assert not numpy.array_equal(reprojected.array.data[0], reprojected.array.data[1])
assert not numpy.array_equal(reprojected.array.data[1], reprojected.array.data[2])

0 comments on commit 68e9a41

Please sign in to comment.