From 61282d1347c0c34b0915eedf21b8fa444e2f3f52 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 13 May 2024 17:39:55 -0400 Subject: [PATCH] Include GeoArrow metadata on constructed Arrow table (#52) --- pyproject.toml | 1 + stac_geoparquet/arrow/_to_arrow.py | 18 ++++++++++++++++++ tests/test_arrow.py | 14 ++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a9f4df0..0003844 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "packaging", "pandas", "pyarrow", + "pyproj", "pystac", "shapely", ] diff --git a/stac_geoparquet/arrow/_to_arrow.py b/stac_geoparquet/arrow/_to_arrow.py index b5b1f06..5688e31 100644 --- a/stac_geoparquet/arrow/_to_arrow.py +++ b/stac_geoparquet/arrow/_to_arrow.py @@ -13,6 +13,8 @@ import shapely import shapely.geometry +from stac_geoparquet.arrow._to_parquet import WGS84_CRS_JSON + def _chunks( lst: Sequence[Dict[str, Any]], n: int @@ -109,6 +111,7 @@ def _process_arrow_table(table: pa.Table, *, downcast: bool = True) -> pa.Table: table = _bring_properties_to_top_level(table) table = _convert_timestamp_columns(table) table = _convert_bbox_to_struct(table, downcast=downcast) + table = _assign_geoarrow_metadata(table) return table @@ -362,3 +365,18 @@ def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool) -> pa.Table: new_chunks.append(struct_arr) return table.set_column(bbox_col_idx, "bbox", new_chunks) + + +def _assign_geoarrow_metadata(table: pa.Table) -> pa.Table: + """Tag the primary geometry column with `geoarrow.wkb` on the field metadata.""" + existing_field_idx = table.schema.get_field_index("geometry") + existing_field = table.schema.field(existing_field_idx) + ext_metadata = {"crs": WGS84_CRS_JSON} + field_metadata = { + b"ARROW:extension:name": b"geoarrow.wkb", + b"ARROW:extension:metadata": json.dumps(ext_metadata).encode("utf-8"), + } + new_field = existing_field.with_metadata(field_metadata) + return table.set_column( + existing_field_idx, new_field, table.column(existing_field_idx) + ) diff --git a/tests/test_arrow.py b/tests/test_arrow.py index 0cc1803..e3f1291 100644 --- a/tests/test_arrow.py +++ b/tests/test_arrow.py @@ -210,6 +210,20 @@ def test_round_trip(collection_id: str): assert_json_value_equal(result, expected, precision=0) +def test_table_contains_geoarrow_metadata(): + collection_id = "naip" + with open(HERE / "data" / f"{collection_id}-pc.json") as f: + items = json.load(f) + + table = parse_stac_items_to_arrow(items) + field_meta = table.schema.field("geometry").metadata + assert field_meta[b"ARROW:extension:name"] == b"geoarrow.wkb" + assert json.loads(field_meta[b"ARROW:extension:metadata"])["crs"]["id"] == { + "authority": "EPSG", + "code": 4326, + } + + def test_to_arrow_deprecated(): with pytest.warns(FutureWarning): import stac_geoparquet.to_arrow