Write to Delta Lake (#58)
* Move json equality logic outside of test_arrow

* Refactor to RawBatch and CleanBatch wrapper types

* Move _from_arrow functions to _api

* Update imports

* fix circular import

* keep deprecated api

* Add write-read test and fix typing

* add parquet tests

* fix ci

* Initial delta lake support

* Manual schema updates

* Add delta lake dep

* Fix export

* fix pyupgrade lint

* Add type hints

* any typing
kylebarron authored Jun 5, 2024
1 parent 0ba3994 commit e43398b
Showing 8 changed files with 168 additions and 30 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -12,6 +12,7 @@ dynamic = ["version", "description"]
requires-python = ">=3.8"
dependencies = [
"ciso8601",
"deltalake",
"geopandas",
"packaging",
"pandas",
19 changes: 11 additions & 8 deletions stac_geoparquet/arrow/_api.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, Optional, Union
from typing import Any, Iterable, Iterator

import pyarrow as pa

@@ -11,10 +13,10 @@


def parse_stac_items_to_arrow(
items: Iterable[Dict[str, Any]],
items: Iterable[dict[str, Any]],
*,
chunk_size: int = 8192,
schema: Optional[Union[pa.Schema, InferredSchema]] = None,
schema: pa.Schema | InferredSchema | None = None,
) -> Iterable[pa.RecordBatch]:
"""Parse a collection of STAC Items to an iterable of :class:`pyarrow.RecordBatch`.
@@ -51,11 +53,11 @@ def parse_stac_items_to_arrow(


def parse_stac_ndjson_to_arrow(
path: Union[str, Path, Iterable[Union[str, Path]]],
path: str | Path | Iterable[str | Path],
*,
chunk_size: int = 65536,
schema: Optional[pa.Schema] = None,
limit: Optional[int] = None,
schema: pa.Schema | None = None,
limit: int | None = None,
) -> Iterator[pa.RecordBatch]:
"""
Convert one or more newline-delimited JSON STAC files to a generator of Arrow
@@ -83,6 +85,7 @@ def parse_stac_ndjson_to_arrow(
if schema is None:
inferred_schema = InferredSchema()
inferred_schema.update_from_json(path, chunk_size=chunk_size, limit=limit)
inferred_schema.manual_updates()
yield from parse_stac_ndjson_to_arrow(
path, chunk_size=chunk_size, schema=inferred_schema
)
@@ -103,7 +106,7 @@ def stac_table_to_items(table: pa.Table) -> Iterable[dict]:


def stac_table_to_ndjson(
table: pa.Table, dest: Union[str, Path, os.PathLike[bytes]]
table: pa.Table, dest: str | Path | os.PathLike[bytes]
) -> None:
"""Write a STAC Table to a newline-delimited JSON file."""
for batch in table.to_batches():
@@ -112,7 +115,7 @@ def stac_table_to_ndjson(


def stac_items_to_arrow(
items: Iterable[Dict[str, Any]], *, schema: Optional[pa.Schema] = None
items: Iterable[dict[str, Any]], *, schema: pa.Schema | None = None
) -> pa.RecordBatch:
"""Convert dicts representing STAC Items to Arrow
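For context, a minimal usage sketch of the refactored Arrow API above. It assumes `parse_stac_ndjson_to_arrow` and `stac_table_to_items` are re-exported from `stac_geoparquet.arrow` (they are defined in `stac_geoparquet.arrow._api`), and `items.ndjson` is a placeholder input file.

import pyarrow as pa

from stac_geoparquet.arrow import parse_stac_ndjson_to_arrow, stac_table_to_items

# Stream newline-delimited STAC Items into Arrow record batches, then collect
# them into a single table.
batches = parse_stac_ndjson_to_arrow("items.ndjson", chunk_size=65536)
table = pa.Table.from_batches(batches)

# Convert the Arrow representation back into plain STAC Item dicts.
for item in stac_table_to_items(table):
    print(item["id"])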
34 changes: 34 additions & 0 deletions stac_geoparquet/arrow/_delta_lake.py
@@ -0,0 +1,34 @@
from __future__ import annotations

import itertools
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable

import pyarrow as pa
from deltalake import write_deltalake

from stac_geoparquet.arrow._api import parse_stac_ndjson_to_arrow
from stac_geoparquet.arrow._to_parquet import create_geoparquet_metadata

if TYPE_CHECKING:
from deltalake import DeltaTable


def parse_stac_ndjson_to_delta_lake(
input_path: str | Path | Iterable[str | Path],
table_or_uri: str | Path | DeltaTable,
*,
chunk_size: int = 65536,
schema: pa.Schema | None = None,
limit: int | None = None,
**kwargs: Any,
) -> None:
batches_iter = parse_stac_ndjson_to_arrow(
input_path, chunk_size=chunk_size, schema=schema, limit=limit
)
first_batch = next(batches_iter)
schema = first_batch.schema.with_metadata(
create_geoparquet_metadata(pa.Table.from_batches([first_batch]))
)
combined_iter = itertools.chain([first_batch], batches_iter)
write_deltalake(table_or_uri, combined_iter, schema=schema, engine="rust", **kwargs)
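
A hedged usage sketch of the new Delta Lake writer, mirroring what the test added in this commit does; the file and table names are illustrative, and `parse_stac_ndjson_to_delta_lake` is imported from the private `_delta_lake` module just as in the test.

from deltalake import DeltaTable

from stac_geoparquet.arrow import stac_table_to_items
from stac_geoparquet.arrow._delta_lake import parse_stac_ndjson_to_delta_lake

# Write newline-delimited STAC Items to a Delta Lake table on disk.
parse_stac_ndjson_to_delta_lake("sentinel-2-l2a.ndjson", "s2-delta-table")

# Read the table back and recover the original item dicts.
table = DeltaTable("s2-delta-table").to_pyarrow_table()
items = list(stac_table_to_items(table))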
56 changes: 52 additions & 4 deletions stac_geoparquet/arrow/_schema/models.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Sequence, Union
from typing import Any, Iterable, Sequence

import pyarrow as pa

@@ -27,10 +29,10 @@ def __init__(self) -> None:

def update_from_json(
self,
path: Union[str, Path, Iterable[Union[str, Path]]],
path: str | Path | Iterable[str | Path],
*,
chunk_size: int = 65536,
limit: Optional[int] = None,
limit: int | None = None,
) -> None:
"""
Update this inferred schema from one or more newline-delimited JSON STAC files.
@@ -45,11 +47,57 @@ def update_from_json(
for batch in read_json_chunked(path, chunk_size=chunk_size, limit=limit):
self.update_from_items(batch)

def update_from_items(self, items: Sequence[Dict[str, Any]]) -> None:
def update_from_items(self, items: Sequence[dict[str, Any]]) -> None:
"""Update this inferred schema from a sequence of STAC Items."""
self.count += len(items)
current_schema = StacJsonBatch.from_dicts(items, schema=None).inner.schema
new_schema = pa.unify_schemas(
[self.inner, current_schema], promote_options="permissive"
)
self.inner = new_schema

def manual_updates(self) -> None:
schema = self.inner
properties_field = schema.field("properties")
properties_schema = pa.schema(properties_field.type)

# The datetime column can be inferred as `null` in the case of a Collection with
# start_datetime and end_datetime. But `null` is incompatible with Delta Lake,
# so we coerce to a Timestamp type.
if pa.types.is_null(properties_schema.field("datetime").type):
field_idx = properties_schema.get_field_index("datetime")
properties_schema = properties_schema.set(
field_idx,
properties_schema.field(field_idx).with_type(
pa.timestamp("us", tz="UTC")
),
)

if "proj:epsg" in properties_schema.names and pa.types.is_null(
properties_schema.field("proj:epsg").type
):
field_idx = properties_schema.get_field_index("proj:epsg")
properties_schema = properties_schema.set(
field_idx,
properties_schema.field(field_idx).with_type(pa.int64()),
)

if "proj:wkt2" in properties_schema.names and pa.types.is_null(
properties_schema.field("proj:wkt2").type
):
field_idx = properties_schema.get_field_index("proj:wkt2")
properties_schema = properties_schema.set(
field_idx,
properties_schema.field(field_idx).with_type(pa.string()),
)

# Note: proj:projjson can also be null, but we don't have a type we can cast
# that to.

properties_idx = schema.get_field_index("properties")
updated_schema = schema.set(
properties_idx,
properties_field.with_type(pa.struct(properties_schema)),
)

self.inner = updated_schema
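
As a standalone illustration of the schema fix-up pattern used in `manual_updates` above, the sketch below replaces a field whose type was inferred as `null` with a concrete timestamp type that Delta Lake can store; the example schema is made up.

import pyarrow as pa

# A properties schema as it might be inferred from a Collection that only has
# start_datetime/end_datetime, so "datetime" came out as null.
properties_schema = pa.schema(
    [
        pa.field("datetime", pa.null()),
        pa.field("proj:epsg", pa.int64()),
    ]
)

if pa.types.is_null(properties_schema.field("datetime").type):
    idx = properties_schema.get_field_index("datetime")
    properties_schema = properties_schema.set(
        idx,
        properties_schema.field(idx).with_type(pa.timestamp("us", tz="UTC")),
    )

assert properties_schema.field("datetime").type == pa.timestamp("us", tz="UTC")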
21 changes: 12 additions & 9 deletions stac_geoparquet/arrow/_to_parquet.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Union
from typing import Any, Iterable

import pyarrow as pa
import pyarrow.parquet as pq
@@ -11,11 +13,12 @@


def parse_stac_ndjson_to_parquet(
input_path: Union[str, Path, Iterable[Union[str, Path]]],
output_path: Union[str, Path],
input_path: str | Path | Iterable[str | Path],
output_path: str | Path,
*,
chunk_size: int = 65536,
schema: Optional[Union[pa.Schema, InferredSchema]] = None,
schema: pa.Schema | InferredSchema | None = None,
limit: int | None = None,
**kwargs: Any,
) -> None:
"""Convert one or more newline-delimited JSON STAC files to GeoParquet
Expand All @@ -32,11 +35,11 @@ def parse_stac_ndjson_to_parquet(
"""

batches_iter = parse_stac_ndjson_to_arrow(
input_path, chunk_size=chunk_size, schema=schema
input_path, chunk_size=chunk_size, schema=schema, limit=limit
)
first_batch = next(batches_iter)
schema = first_batch.schema.with_metadata(
_create_geoparquet_metadata(pa.Table.from_batches([first_batch]))
create_geoparquet_metadata(pa.Table.from_batches([first_batch]))
)
with pq.ParquetWriter(output_path, schema, **kwargs) as writer:
writer.write_batch(first_batch)
@@ -54,13 +57,13 @@ def to_parquet(table: pa.Table, where: Any, **kwargs: Any) -> None:
where: The destination for saving.
"""
metadata = table.schema.metadata or {}
metadata.update(_create_geoparquet_metadata(table))
metadata.update(create_geoparquet_metadata(table))
table = table.replace_schema_metadata(metadata)

pq.write_table(table, where, **kwargs)


def _create_geoparquet_metadata(table: pa.Table) -> dict[bytes, bytes]:
def create_geoparquet_metadata(table: pa.Table) -> dict[bytes, bytes]:
# TODO: include bbox of geometries
column_meta = {
"encoding": "WKB",
@@ -77,7 +80,7 @@ def _create_geoparquet_metadata(table: pa.Table) -> dict[bytes, bytes]:
}
},
}
geo_meta: Dict[str, Any] = {
geo_meta: dict[str, Any] = {
"version": "1.1.0-dev",
"columns": {"geometry": column_meta},
"primary_column": "geometry",
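A minimal sketch of the updated Parquet entry point, exercising the newly added `limit` keyword; the paths are placeholders and the import assumes the function is re-exported from `stac_geoparquet.arrow`.

from stac_geoparquet.arrow import parse_stac_ndjson_to_parquet

parse_stac_ndjson_to_parquet(
    "items.ndjson",
    "items.parquet",
    chunk_size=65536,
    limit=10_000,  # stop after the first 10,000 items
)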
7 changes: 4 additions & 3 deletions stac_geoparquet/cli.py
@@ -1,15 +1,16 @@
from __future__ import annotations

import argparse
import logging
import sys
import os
from typing import List, Optional

from stac_geoparquet import pc_runner

logger = logging.getLogger("stac_geoparquet.pgstac_reader")


def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
def parse_args(args: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--output-protocol",
@@ -90,7 +91,7 @@ def setup_logging() -> None:
}


def main(inp: Optional[List[str]] = None) -> int:
def main(inp: list[str] | None = None) -> int:
import azure.data.tables

args = parse_args(inp)
14 changes: 8 additions & 6 deletions stac_geoparquet/json_reader.py
@@ -1,16 +1,18 @@
"""Return an iterator of items from an ndjson, a json array of items, or a featurecollection of items."""

from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Sequence, Union
from typing import Any, Iterable, Sequence

import orjson

from stac_geoparquet.arrow._util import batched_iter


def read_json(
path: Union[str, Path, Iterable[Union[str, Path]]],
) -> Iterable[Dict[str, Any]]:
path: str | Path | Iterable[str | Path],
) -> Iterable[dict[str, Any]]:
"""Read a json or ndjson file."""
if isinstance(path, (str, Path)):
path = [path]
@@ -39,10 +41,10 @@ def read_json(


def read_json_chunked(
path: Union[str, Path, Iterable[Union[str, Path]]],
path: str | Path | Iterable[str | Path],
chunk_size: int,
*,
limit: Optional[int] = None,
) -> Iterable[Sequence[Dict[str, Any]]]:
limit: int | None = None,
) -> Iterable[Sequence[dict[str, Any]]]:
"""Read from a JSON or NDJSON file in chunks of `chunk_size`."""
return batched_iter(read_json(path), chunk_size, limit=limit)
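
A small usage sketch of the chunked reader, assuming a placeholder `items.ndjson` input file.

from stac_geoparquet.json_reader import read_json_chunked

# Yields sequences of at most 1,000 item dicts, stopping after 5,000 items total.
for chunk in read_json_chunked("items.ndjson", chunk_size=1000, limit=5000):
    print(len(chunk))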
46 changes: 46 additions & 0 deletions tests/test_delta_lake.py
@@ -0,0 +1,46 @@
import json
from pathlib import Path

import pytest
from deltalake import DeltaTable

from stac_geoparquet.arrow import stac_table_to_items
from stac_geoparquet.arrow._delta_lake import parse_stac_ndjson_to_delta_lake

from .json_equals import assert_json_value_equal

HERE = Path(__file__).parent

TEST_COLLECTIONS = [
"3dep-lidar-copc",
# "3dep-lidar-dsm",
"cop-dem-glo-30",
"io-lulc-annual-v02",
# "io-lulc",
"landsat-c2-l1",
"landsat-c2-l2",
"naip",
"planet-nicfi-analytic",
"sentinel-1-rtc",
"sentinel-2-l2a",
"us-census",
]


@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS)
def test_round_trip_via_delta_lake(collection_id: str, tmp_path: Path):
path = HERE / "data" / f"{collection_id}-pc.json"
out_path = tmp_path / collection_id
parse_stac_ndjson_to_delta_lake(path, out_path)

# Read back into table and convert to json
dt = DeltaTable(out_path)
table = dt.to_pyarrow_table()
items_result = list(stac_table_to_items(table))

# Compare with original json
with open(HERE / "data" / f"{collection_id}-pc.json") as f:
items = json.load(f)

for result, expected in zip(items_result, items):
assert_json_value_equal(result, expected, precision=0)
