diff --git a/ibis-server/app/model/metadata/dto.py b/ibis-server/app/model/metadata/dto.py index 85ec3f8d9..74dc768b1 100644 --- a/ibis-server/app/model/metadata/dto.py +++ b/ibis-server/app/model/metadata/dto.py @@ -51,6 +51,11 @@ class RustWrenEngineColumnType(Enum): TIME = "TIME" NULL = "NULL" + # Extension types + ## PostGIS + GEOMETRY = "GEOMETRY" + GEOGRAPHY = "GEOGRAPHY" + class Column(BaseModel): name: str diff --git a/ibis-server/app/model/metadata/postgres.py b/ibis-server/app/model/metadata/postgres.py index 6a5b5720d..9a4324a98 100644 --- a/ibis-server/app/model/metadata/postgres.py +++ b/ibis-server/app/model/metadata/postgres.py @@ -11,6 +11,92 @@ from app.model.metadata.metadata import Metadata +class ExtensionHandler: + def __init__(self, connection): + self.connection = connection + + self.handlers = { + "postgis": self.postgis_handler, + } + + def augment(self, tables: list[Table]) -> list[Table]: + # Get the list of extensions from the database + extensions = self.get_extensions() + + # Iterate through the extensions and call the appropriate handler + for ext in extensions: + ext_name = ext["extension_name"] + schema_name = ext["schema_name"] + + if ext_name in self.handlers: + handler = self.handlers[ext_name] + tables = handler(tables, schema_name) + + return tables + + def get_extensions(self) -> list[str]: + sql = """ + SELECT + e.extname AS extension_name, + n.nspname AS schema_name + FROM + pg_extension e + JOIN + pg_namespace n ON n.oid = e.extnamespace; + """ + df = self.connection.sql(sql).to_pandas() + if df.empty: + return [] + response = df.to_dict(orient="records") + return response + + def postgis_handler(self, tables: list[Table], schema_name: str) -> list[Table]: + # Get the list of geometry and geography columns + sql = f""" + SELECT + f_table_schema, + f_table_name, + f_geometry_column AS column_name, + 'geometry' AS column_type + FROM + {schema_name}.geometry_columns + UNION ALL + SELECT + f_table_schema, + f_table_name, + f_geography_column AS column_name, + 'geography' AS column_type + FROM + {schema_name}.geography_columns; + """ + response = self.connection.sql(sql).to_pandas().to_dict(orient="records") + + # Update tables + for row in response: + # TODO: Might want to use a global `_format_postgres_compact_table_name` function. + table_name = f"{row['f_table_schema']}.{row['f_table_name']}" + table = tables[table_name] + for column in table.columns: + column.type = str( + self._transform_postgres_column_type(row["column_type"]) + ) + break + + return tables + + def _transform_postgres_column_type(self, data_type): + # lower case the data_type + data_type = data_type.lower() + + # Extension types + switcher = { + "geometry": RustWrenEngineColumnType.GEOMETRY, + "geography": RustWrenEngineColumnType.GEOGRAPHY, + } + + return switcher.get(data_type, RustWrenEngineColumnType.UNKNOWN) + + class PostgresMetadata(Metadata): def __init__(self, connection_info: PostgresConnectionInfo): super().__init__(connection_info) @@ -80,6 +166,8 @@ def get_table_list(self) -> list[Table]: properties=None, ) ) + extension_handler = ExtensionHandler(self.connection) + unique_tables = extension_handler.augment(unique_tables) return list(unique_tables.values()) def get_constraints(self) -> list[Constraint]: diff --git a/ibis-server/pyproject.toml b/ibis-server/pyproject.toml index 4ad773161..e04554051 100644 --- a/ibis-server/pyproject.toml +++ b/ibis-server/pyproject.toml @@ -25,6 +25,8 @@ httpx = "0.28.1" python-dotenv = "1.0.1" orjson = "3.10.16" pandas = "2.2.3" +geopandas = "^1.0.1" +geoalchemy2 = "^0.17.1" sqlglot = { extras = [ "rs", ], version = ">=23.4,<26.5" } # the version should follow the ibis-framework diff --git a/ibis-server/tests/resource/tpch/data/cities_geometry.parquet b/ibis-server/tests/resource/tpch/data/cities_geometry.parquet new file mode 100644 index 000000000..f10c6e5b1 Binary files /dev/null and b/ibis-server/tests/resource/tpch/data/cities_geometry.parquet differ diff --git a/ibis-server/tests/routers/v2/connector/test_postgres.py b/ibis-server/tests/routers/v2/connector/test_postgres.py index 6e868da9f..3e9be35d9 100644 --- a/ibis-server/tests/routers/v2/connector/test_postgres.py +++ b/ibis-server/tests/routers/v2/connector/test_postgres.py @@ -1,6 +1,7 @@ import base64 from urllib.parse import quote_plus, urlparse +import geopandas as gpd import orjson import pandas as pd import psycopg @@ -149,6 +150,22 @@ def postgres(request) -> PostgresContainer: return pg +# PostGIS only provides the amd64 images. To run the PostGIS tests on ARM devices, +# please manually download the image by using the command below. +# docker pull --platform linux/amd64 postgis/postgis:16-3.5-alpine +@pytest.fixture(scope="module") +def postgis(request) -> PostgresContainer: + pg = PostgresContainer("postgis/postgis:16-3.5-alpine").start() + engine = sqlalchemy.create_engine(pg.get_connection_url()) + with engine.begin() as conn: + conn.execute(text("CREATE EXTENSION IF NOT EXISTS postgis")) + gpd.read_parquet( + file_path("resource/tpch/data/cities_geometry.parquet") + ).to_postgis("cities_geometry", engine, index=False) + request.addfinalizer(pg.stop) + return pg + + async def test_query(client, manifest_str, postgres: PostgresContainer): connection_info = _to_connection_info(postgres) response = await client.post( @@ -1050,6 +1067,25 @@ async def test_model_substitute_non_existent_column( assert 'column "x" does not exist' in response.text +async def test_postgis_geometry(client, manifest_str, postgis: PostgresContainer): + connection_info = _to_connection_info(postgis) + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": ( + "SELECT ST_Distance(a.geometry, b.geometry) AS distance " + "FROM cities_geometry a, cities_geometry b " + "WHERE a.\"City\" = 'London' AND b.\"City\" = 'New York'" + ), + }, + ) + assert response.status_code == 200 + result = response.json() + assert result["data"][0] == ["74.6626535"] + + def _to_connection_info(pg: PostgresContainer): return { "host": pg.get_container_host_ip(),