This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

fix methods to infer geom and bbox #3

Open
wants to merge 5 commits into main
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
 __pycache__
 *.parquet
 dist/
+.vscode
63 changes: 37 additions & 26 deletions stac_table.py
@@ -7,14 +7,14 @@
 from typing import TypeVar
 
 import dask
-import pystac
-import pandas as pd
-import pyarrow
-import pyarrow.parquet
-import fsspec
 import dask_geopandas
+import fsspec
+import pandas as pd
+import pyarrow as pa
+import pyproj
+import pystac
 import shapely.geometry
+
+from shapely.ops import transform
 
 T = TypeVar("T", pystac.Collection, pystac.Item)
 SCHEMA_URI = "https://stac-extensions.github.io/table/v1.2.0/schema.json"
@@ -155,11 +155,13 @@ def generate(
         or infer_datetime != InferDatetimeOptions.no
         or proj is True
     ):
-        # # TODO: this doesn't actually work
-        # data = dask_geopandas.read_parquet(
-        #     ds.files, storage_options={"filesystem": ds.filesystem}
-        # )
         data = dask_geopandas.read_parquet(uri, storage_options=storage_options)
+        # # TODO: this doesn't actually work
+        # data = dask_geopandas.read_parquet(
+        #     ds.files, storage_options={"filesystem": ds.filesystem}
+        # )
+        if data.spatial_partitions is None:
+            data.calculate_spatial_partitions()
 
     columns = get_columns(ds)
     template.properties["table:columns"] = columns
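For context on the new guard: dask_geopandas only knows per-partition extents after they are computed. A minimal standalone sketch of the pattern, assuming a local GeoParquet file at a hypothetical path:

import dask_geopandas

# Lazily open a (possibly partitioned) GeoParquet dataset; the path is hypothetical.
data = dask_geopandas.read_parquet("naturalearth.parquet")

# spatial_partitions is None until calculated; once filled in, each entry is a
# geometry covering one partition, so overall bounds can be derived without a full scan.
if data.spatial_partitions is None:
    data.calculate_spatial_partitions()

print(data.spatial_partitions.unary_union.bounds)  # overall (minx, miny, maxx, maxy)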
@@ -176,16 +178,27 @@
 
     extra_proj = {}
     if infer_bbox:
+        src_crs = data.spatial_partitions.crs.to_epsg()
+        tf = pyproj.Transformer.from_crs(src_crs, 4326, always_xy=True)
+
         bbox = data.spatial_partitions.unary_union.bounds
-        # TODO: may need to convert to epsg:4326
+        # NOTE: bbox of unary union will be stored under proj extension as projected
         extra_proj["proj:bbox"] = bbox
-        template.bbox = bbox
 
+        # NOTE: bbox will be stored in pystac.Item.bbox in EPSG:4326
+        bbox = transform(tf.transform, shapely.geometry.box(*bbox))
+        template.bbox = bbox.bounds
+
     if infer_geometry:
-        geometry = shapely.geometry.mapping(data.unary_union.compute())
-        # TODO: may need to convert to epsg:4326
-        extra_proj["proj:geometry"] = geometry
-        template.geometry = geometry
+        # NOTE: geom under proj extension as projected
+        geometry = data.unary_union.compute()
+        extra_proj["proj:geometry"] = shapely.geometry.mapping(geometry)
+
+        # NOTE: geometry will be stored in pystac.Item.geometry in EPSG:4326
+        src_crs = data.spatial_partitions.crs.to_epsg()
+        tf = pyproj.Transformer.from_crs(src_crs, 4326, always_xy=True)
+        geometry = transform(tf.transform, geometry)
+        template.geometry = shapely.geometry.mapping(geometry)
 
     if infer_bbox and template.geometry is None:
         # If bbox is set then geometry must be set as well.
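The reprojection pattern added here stands on its own; a sketch assuming a source CRS of EPSG:32633 and made-up UTM coordinates:

import pyproj
import shapely.geometry
from shapely.ops import transform

src_crs = 32633  # example projected CRS (UTM zone 33N)
bbox = (500000.0, 4649776.0, 510000.0, 4659776.0)  # made-up projected bounds

# always_xy=True forces (x, y) = (lon, lat) ordering regardless of CRS axis convention.
tf = pyproj.Transformer.from_crs(src_crs, 4326, always_xy=True)

# Turn the bounds into a polygon, reproject every vertex, then take WGS84 bounds.
geom_4326 = transform(tf.transform, shapely.geometry.box(*bbox))
print(geom_4326.bounds)  # (min_lon, min_lat, max_lon, max_lat)

Note that only the box's four corners are transformed; for CRSes whose straight edges curve in lon/lat space, densifying the box edges before transforming would give a tighter EPSG:4326 bbox.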
@@ -200,7 +213,8 @@ def generate(
         template.properties.update(**extra_proj, **proj)
 
     if infer_datetime != InferDatetimeOptions.no and datetime_column is None:
-        raise ValueError("Must specify 'datetime_column' when 'infer_datetime != no'.")
+        msg = "Must specify 'datetime_column' when 'infer_datetime != no'."
+        raise ValueError(msg)
 
     if infer_datetime == InferDatetimeOptions.midpoint:
         values = dask.compute(data[datetime_column].min(), data[datetime_column].max())
@@ -210,7 +224,8 @@
         values = data[datetime_column].unique().compute()
         n = len(values)
         if n > 1:
-            raise ValueError(f"infer_datetime='unique', but {n} unique values found.")
+            msg = f"infer_datetime='unique', but {n} unique values found."
+            raise ValueError(msg)
         template.properties["datetime"] = values[0].to_pydatetime()
 
     if infer_datetime == InferDatetimeOptions.range:
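The three infer_datetime modes reduce to simple aggregations; a sketch over a plain pandas Series (values are made up):

import pandas as pd

times = pd.Series(pd.to_datetime(["2021-01-01", "2021-01-03", "2021-01-05"]))

# midpoint: one representative datetime halfway between min and max.
start, end = times.min(), times.max()
print(start + (end - start) / 2)  # 2021-01-03 00:00:00

# unique: only valid when exactly one distinct value is present.
n = times.nunique()
print(n > 1)  # True here, so infer_datetime='unique' would raise

# range: keep both endpoints (start_datetime / end_datetime in STAC terms).
print(start.isoformat(), end.isoformat())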
@@ -258,7 +273,7 @@ def get_proj(ds):
     return proj
 
 
-def get_columns(ds: pyarrow.parquet.ParquetDataset) -> list:
+def get_columns(ds: pa.parquet.ParquetDataset) -> list:
     columns = []
     fragment = ds.fragments[0]
 
@@ -273,13 +288,9 @@ def get_columns(ds: pyarrow.parquet.ParquetDataset) -> list:
     return columns
 
 
-def parquet_dataset_from_url(
-    url: str, storage_options
-):
+def parquet_dataset_from_url(url: str, storage_options):
     fs, _, _ = fsspec.get_fs_token_paths(url, storage_options=storage_options)
-    pa_fs = pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(fs))
+    pa_fs = pa.fs.PyFileSystem(pa.fs.FSSpecHandler(fs))
     url2 = url.split("://", 1)[-1]  # pyarrow doesn't auto-strip the prefix.
-    ds = pyarrow.parquet.ParquetDataset(
-        url2, filesystem=pa_fs, use_legacy_dataset=False
-    )
+    ds = pa.parquet.ParquetDataset(url2, filesystem=pa_fs, use_legacy_dataset=False)
    return ds
18 changes: 14 additions & 4 deletions test_stac_table.py
@@ -3,15 +3,25 @@
 import shutil
 from pathlib import Path
 
-import pandas as pd
 import dask_geopandas
 import geopandas
+import pandas as pd
+import pyproj
 import pystac
 import pytest
+import shapely.geometry
 
 import stac_table
 
 
+def is_valid_epsg(epsg_code):
+    try:
+        pyproj.CRS.from_user_input(epsg_code)
+        return True
+    except pyproj.exceptions.CRSError:
+        return False
+
+
 @pytest.fixture
 def ensure_clean():
     yield
@@ -51,11 +61,11 @@ def test_generate_item(self, partition):
         )
 
         expected_columns = [
-            {"name": "pop_est", "type": "int64"},
+            {"name": "pop_est", "type": "double"},
             {"name": "continent", "type": "byte_array"},
             {"name": "name", "type": "byte_array"},
             {"name": "iso_a3", "type": "byte_array"},
-            {"name": "gdp_md_est", "type": "double"},
+            {"name": "gdp_md_est", "type": "int64"},
             {"name": "geometry", "type": "byte_array"},
         ]
         assert result.properties["table:columns"] == expected_columns
@@ -67,7 +77,7 @@ def test_generate_item(self, partition):
         assert asset.extra_fields["table:storage_options"] == {"storage_account": "foo"}
 
         assert pystac.extensions.projection.SCHEMA_URI in result.stac_extensions
-        assert result.properties["proj:epsg"] == 4326
+        assert is_valid_epsg(result.properties["proj:epsg"])
 
     def test_infer_bbox(self):
         df = geopandas.GeoDataFrame(
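The test now asserts only that proj:epsg resolves to a real CRS instead of pinning 4326. A quick illustration of the helper's behavior (inputs are arbitrary):

import pyproj

def is_valid_epsg(epsg_code):
    try:
        pyproj.CRS.from_user_input(epsg_code)
        return True
    except pyproj.exceptions.CRSError:
        return False

print(is_valid_epsg(4326))    # True: WGS84
print(is_valid_epsg(999999))  # False: no such EPSG code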