diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 276d6971..71b6c663 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -29,7 +29,9 @@ import trino from tests.integration.conftest import trino_version from trino import constants +from trino.client import InlineSegment from trino.client import SegmentIterator +from trino.client import SpooledSegment from trino.dbapi import Cursor from trino.dbapi import DescribeOutput from trino.dbapi import TimeBoundLRUCache @@ -1883,9 +1885,17 @@ def test_segments_cursor(trino_connection): row_mapper = RowMapperFactory().create(columns=cur._query.columns, legacy_primitive_types=False) total = 0 for segment in segments: + assert isinstance(segment.segment, (InlineSegment, SpooledSegment)), ( + f"Expected InlineSegment or SpooledSegment, got {type(segment.segment)}" + ) assert segment.encoding == trino_connection._client_session.encoding - assert isinstance(segment.segment.uri, str), f"Expected string for uri, got {segment.segment.uri}" - assert isinstance(segment.segment.ack_uri, str), f"Expected string for ack_uri, got {segment.segment.ack_uri}" + if isinstance(segment.segment, SpooledSegment): + assert isinstance(segment.segment.uri, str), ( + f"Expected string for uri, got {type(segment.segment.uri)}" + ) + assert isinstance(segment.segment.ack_uri, str), ( + f"Expected string for ack_uri, got {type(segment.segment.ack_uri)}" + ) total += len(list(SegmentIterator(segment, row_mapper))) assert total == 300875, f"Expected total rows 300875, got {total}"