Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally use pyarrow types in to_geodataframe #31

Merged
merged 3 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 120 additions & 33 deletions stac_geoparquet/stac_geoparquet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""
Generate geoparquet from a sequence of STAC items.
"""

from __future__ import annotations
import collections

from typing import Sequence, Any
from typing import Sequence, Any, Literal
import warnings

import pystac
import geopandas
import pandas as pd
import pyarrow as pa
import numpy as np
import shapely.geometry

Expand All @@ -16,7 +20,7 @@
from stac_geoparquet.utils import fix_empty_multipolygon

STAC_ITEM_TYPES = ["application/json", "application/geo+json"]

DTYPE_BACKEND = Literal["numpy_nullable", "pyarrow"]
SELF_LINK_COLUMN = "self_link"


Expand All @@ -31,7 +35,10 @@ def _fix_array(v):


def to_geodataframe(
items: Sequence[dict[str, Any]], add_self_link: bool = False
items: Sequence[dict[str, Any]],
add_self_link: bool = False,
dtype_backend: DTYPE_BACKEND | None = None,
datetime_precision: str = "ns",
) -> geopandas.GeoDataFrame:
"""
Convert a sequence of STAC items to a :class:`geopandas.GeoDataFrame`.
Expand All @@ -42,19 +49,72 @@ def to_geodataframe(
Parameters
----------
items: A sequence of STAC items.
add_self_link: Add the absolute link (if available) to the source STAC Item as a separate column named "self_link"
add_self_link: bool, default False
Add the absolute link (if available) to the source STAC Item
as a separate column named "self_link"
dtype_backend: {'pyarrow', 'numpy_nullable'}, optional
The dtype backend to use for storing arrays.

By default, this will use 'numpy_nullable' and emit a
FutureWarning that the default will change to 'pyarrow' in
the next release.

Set to 'numpy_nullable' to silence the warning and accept the
old behavior.

Set to 'pyarrow' to silence the warning and accept the new behavior.

There are some differences in the output as well: with
``dtype_backend="pyarrow"``, struct-like fields will explicitly
contain null values for fields that appear in only some of the
records. For example, given an ``assets`` like::

{
"a": {
"href": "a.tif",
},
"b": {
"href": "b.tif",
"title": "B",
}
}

The ``assets`` field of the output for the first row with
``dtype_backend="numpy_nullable"`` will be a Python dictionary with
just ``{"href": "a.tif"}``.

With ``dtype_backend="pyarrow"``, this will be a pyarrow struct
with fields ``{"href": "a.tif", "title": None}``. pyarrow will
infer that the struct field ``asset.title`` is nullable.

datetime_precision: str, default "ns"
The precision to use for the datetime columns. For example,
"us" is microsecond and "ns" is nanosecond.

Returns
-------
The converted GeoDataFrame.
"""
items2 = []
items2 = collections.defaultdict(list)

for item in items:
item2 = {k: v for k, v in item.items() if k != "properties"}
keys = set(item) - {"properties", "geometry"}

for k in keys:
items2[k].append(item[k])

item_geometry = item["geometry"]
if item_geometry:
item_geometry = fix_empty_multipolygon(item_geometry)

items2["geometry"].append(item_geometry)

for k, v in item["properties"].items():
if k in item2:
raise ValueError("k", k)
item2[k] = v
if k in item:
msg = f"Key '{k}' appears in both 'properties' and the top level."
raise ValueError(msg)
items2[k].append(v)

if add_self_link:
self_href = None
for link in item["links"]:
Expand All @@ -65,23 +125,11 @@ def to_geodataframe(
):
self_href = link["href"]
break
item2[SELF_LINK_COLUMN] = self_href
items2.append(item2)

# Filter out missing geoms in MultiPolygons
# https://github.com/shapely/shapely/issues/1407
# geometry = [shapely.geometry.shape(x["geometry"]) for x in items2]

geometry = []
for item2 in items2:
item_geometry = item2["geometry"]
if item_geometry:
item_geometry = fix_empty_multipolygon(item_geometry) # type: ignore
geometry.append(item_geometry)

gdf = geopandas.GeoDataFrame(items2, geometry=geometry, crs="WGS84")
items2[SELF_LINK_COLUMN].append(self_href)

for column in [
# TODO: Ideally we wouldn't have to hard-code this list.
# Could we get it from the JSON schema.
DATETIME_COLUMNS = {
"datetime", # common metadata
"start_datetime",
"end_datetime",
Expand All @@ -90,9 +138,43 @@ def to_geodataframe(
"expires", # timestamps extension
"published",
"unpublished",
]:
if column in gdf.columns:
gdf[column] = pd.to_datetime(gdf[column], format="ISO8601")
}

items2["geometry"] = geopandas.array.from_shapely(items2["geometry"])

if dtype_backend is None:
msg = (
"The default argument for 'dtype_backend' will change from "
"'numpy_nullable' to 'pyarrow'. To keep the previous default "
"specify ``dtype_backend='numpy_nullable'``. To accept the future "
"behavior specify ``dtype_backend='pyarrow'."
)
warnings.warn(FutureWarning(msg))
dtype_backend = "numpy_nullable"

if dtype_backend == "pyarrow":
for k, v in items2.items():
if k in DATETIME_COLUMNS:
dt = pd.to_datetime(v, format="ISO8601").as_unit(datetime_precision)
items2[k] = pd.arrays.ArrowExtensionArray(pa.array(dt))

elif k != "geometry":
items2[k] = pd.arrays.ArrowExtensionArray(pa.array(v))

elif dtype_backend == "numpy_nullable":
for k, v in items2.items():
if k in DATETIME_COLUMNS:
items2[k] = pd.to_datetime(v, format="ISO8601").as_unit(
datetime_precision
)

if k in {"type", "stac_version", "id", "collection", SELF_LINK_COLUMN}:
items2[k] = pd.array(v, dtype="string")
else:
msg = f"Invalid 'dtype_backend={dtype_backend}'."
raise TypeError(msg)

gdf = geopandas.GeoDataFrame(items2, geometry="geometry", crs="WGS84")

columns = [
"type",
Expand All @@ -111,10 +193,6 @@ def to_geodataframe(
columns.remove(col)

gdf = pd.concat([gdf[columns], gdf.drop(columns=columns)], axis="columns")
for k in ["type", "stac_version", "id", "collection", SELF_LINK_COLUMN]:
if k in gdf:
gdf[k] = gdf[k].astype("string")

return gdf


Expand Down Expand Up @@ -144,12 +222,16 @@ def to_dict(record: dict) -> dict:

if k == SELF_LINK_COLUMN:
continue
elif k == "assets":
item[k] = {k2: v2 for k2, v2 in v.items() if v2 is not None}
elif k in top_level_keys:
item[k] = v
else:
properties[k] = v

item["geometry"] = shapely.geometry.mapping(item["geometry"])
if item["geometry"]:
item["geometry"] = shapely.geometry.mapping(item["geometry"])

item["properties"] = properties

return item
Expand All @@ -175,6 +257,11 @@ def to_item_collection(df: geopandas.GeoDataFrame) -> pystac.ItemCollection:
include=["datetime64[ns, UTC]", "datetime64[ns]"]
).columns
for k in datelike:
# %f isn't implemented in pyarrow
# https://github.com/apache/arrow/issues/20146
if isinstance(df2[k].dtype, pd.ArrowDtype):
df2[k] = df2[k].astype("datetime64[ns, utc]")

df2[k] = (
df2[k].dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ").fillna("").replace({"": None})
)
Expand Down
44 changes: 36 additions & 8 deletions stac_geoparquet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,27 @@


@functools.singledispatch
def assert_equal(result: Any, expected: Any) -> bool:
def assert_equal(result: Any, expected: Any, ignore_none: bool = False) -> bool:
raise TypeError(f"Invalid type {type(result)}")


@assert_equal.register(pystac.ItemCollection)
def assert_equal_ic(
result: pystac.ItemCollection, expected: pystac.ItemCollection
result: pystac.ItemCollection,
expected: pystac.ItemCollection,
ignore_none: bool = False,
) -> None:
assert type(result) == type(expected)
assert len(result) == len(expected)
assert result.extra_fields == expected.extra_fields
for a, b in zip(result.items, expected.items):
assert_equal(a, b)
assert_equal(a, b, ignore_none=ignore_none)


@assert_equal.register(pystac.Item)
def assert_equal_item(result: pystac.Item, expected: pystac.Item) -> None:
def assert_equal_item(
result: pystac.Item, expected: pystac.Item, ignore_none: bool = False
) -> None:
assert type(result) == type(expected)
assert result.id == expected.id
assert shapely.geometry.shape(result.geometry) == shapely.geometry.shape(
Expand All @@ -41,20 +45,44 @@ def assert_equal_item(result: pystac.Item, expected: pystac.Item) -> None:
expected_links = sorted(expected.links, key=lambda x: x.href)
assert len(result_links) == len(expected_links)
for a, b in zip(result_links, expected_links):
assert_equal(a, b)
assert_equal(a, b, ignore_none=ignore_none)

assert set(result.assets) == set(expected.assets)
for k in result.assets:
assert_equal(result.assets[k], expected.assets[k])
assert_equal(result.assets[k], expected.assets[k], ignore_none=ignore_none)


@assert_equal.register(pystac.Link)
@assert_equal.register(pystac.Asset)
def assert_link_equal(
result: pystac.Link | pystac.Asset, expected: pystac.Link | pystac.Asset
result: pystac.Link | pystac.Asset,
expected: pystac.Link | pystac.Asset,
ignore_none: bool = False,
) -> None:
assert type(result) == type(expected)
assert result.to_dict() == expected.to_dict()
resultd = result.to_dict()
expectedd = expected.to_dict()

left = {}

if ignore_none:
for k, v in resultd.items():
if v is None and k not in expectedd:
pass
elif isinstance(v, list) and k in expectedd:
out = []
for val in v:
if isinstance(val, dict):
out.append({k: v2 for k, v2 in val.items() if v2 is not None})
else:
out.append(val)
left[k] = out
else:
left[k] = v
else:
left = resultd

assert left == expectedd


def fix_empty_multipolygon(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pgstac_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_naip_item():
expected.remove_links(rel=pystac.RelType.SELF)
result.remove_links(rel=pystac.RelType.SELF)

assert_equal(result, expected)
assert_equal(result, expected, ignore_none=True)


def test_sentinel2_l2a():
Expand All @@ -139,7 +139,7 @@ def test_sentinel2_l2a():
result.remove_links(rel=pystac.RelType.SELF)

expected.remove_links(rel=pystac.RelType.LICENSE)
assert_equal(result, expected)
assert_equal(result, expected, ignore_none=True)


def test_generate_endpoints():
Expand Down
Loading
Loading