Skip to content

feat: Apply post_process automatically to all stream types #3023

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ unfixable = [
"INP", # Allow implicit namespace packages in tests
"PLR2004",
"S101",
"SLF001",
"SLF001", # Allow private method access in tests
"PLC2701", # Allow usage of private members in tests
"PLR6301", # Don't suggest making test methods static, etc.
]
Expand Down
16 changes: 12 additions & 4 deletions singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,13 +1145,16 @@ def _sync_records( # noqa: C901
Yields:
Each record from the source.
"""
# Type definitions
context_element: types.Context | None
record: types.Record | None
context_list: list[types.Context] | list[dict] | None

# Initialize metrics
record_counter = metrics.record_counter(self.name)
timer = metrics.sync_timer(self.name)

record_index = 0
context_element: types.Context | None
context_list: list[types.Context] | list[dict] | None
context_list = [context] if context is not None else self.partitions
selected = self.selected

Expand All @@ -1178,6 +1181,11 @@ def _sync_records( # noqa: C901
record, child_context = record_result
else:
record = record_result

record = self.post_process(record, current_context)
if record is None:
continue

try:
self._process_record(
record,
Expand Down Expand Up @@ -1419,7 +1427,7 @@ def generate_child_contexts(
def get_records(
self,
context: types.Context | None,
) -> t.Iterable[dict | tuple[dict, dict | None]]:
) -> t.Iterable[types.Record | tuple[dict, dict | None]]:
"""Abstract record generator function. Must be overridden by the child class.

Each record emitted should be a dictionary of property names to their values.
Expand Down Expand Up @@ -1491,7 +1499,7 @@ def post_process( # noqa: PLR6301
self,
row: types.Record,
context: types.Context | None = None, # noqa: ARG002
) -> dict | None:
) -> types.Record | None:
"""As needed, append or transform raw data to match expected structure.

Optional. This method gives developers an opportunity to "clean up" the results
Expand Down
7 changes: 1 addition & 6 deletions singer_sdk/streams/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,12 +622,7 @@ def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]:
Yields:
One item per (possibly processed) record in the API.
"""
for record in self.request_records(context):
transformed_record = self.post_process(record, context)
if transformed_record is None:
# Record filtered out during post_process()
continue
yield transformed_record
yield from self.request_records(context)

# Abstract methods:

Expand Down
13 changes: 5 additions & 8 deletions singer_sdk/streams/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

if t.TYPE_CHECKING:
from singer_sdk.connectors.sql import FullyQualifiedName
from singer_sdk.helpers.types import Context
from singer_sdk.helpers.types import Context, Record
from singer_sdk.tap_base import Tap


Expand Down Expand Up @@ -168,7 +168,7 @@ def effective_schema(self) -> dict:
return super().effective_schema

# Get records from stream
def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]:
def get_records(self, context: Context | None) -> t.Iterable[Record]:
"""Return a generator of record-type dictionary objects.

If the stream has a replication_key value defined, records will be sorted by the
Expand Down Expand Up @@ -218,12 +218,9 @@ def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]:
query = query.limit(self.ABORT_AT_RECORD_COUNT + 1)

with self.connector._connect() as conn: # noqa: SLF001
for record in conn.execute(query).mappings():
transformed_record = self.post_process(dict(record))
if transformed_record is None:
# Record filtered out during post_process()
continue
yield transformed_record
for row in conn.execute(query).mappings():
# https://github.com/sqlalchemy/sqlalchemy/discussions/10053#discussioncomment-6344965
yield dict(row)

@property
def is_sorted(self) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions tests/core/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def get_records(
context: dict | None,
) -> t.Iterable[dict[str, t.Any]]:
"""Generate records."""
yield {"id": 1, "value": "Egypt"}
yield {"id": 2, "value": "Germany"}
yield {"id": 3, "value": "India"}
yield {"id": 1, "value": "Egypt", "updatedAt": "2021-01-01T00:00:00Z"}
yield {"id": 2, "value": "Germany", "updatedAt": "2021-01-01T00:00:01Z"}
yield {"id": 3, "value": "India", "updatedAt": "2021-01-01T00:00:02Z"}

@contextmanager
def with_replication_method(self, method: str | None) -> t.Iterator[None]:
Expand Down
38 changes: 38 additions & 0 deletions tests/core/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import requests_mock

from singer_sdk import Stream, Tap
from singer_sdk.helpers.types import Context, Record
from tests.core.conftest import SimpleTestTap

CONFIG_START_DATE = "2021-01-01"
Expand Down Expand Up @@ -741,3 +742,40 @@ def discover_streams(self):
assert all(
tap.streams[stream].selected is selection[stream] for stream in selection
)


def test_post_process_drops_record(tap: Tap):
    """Test that records for which ``post_process`` returns ``None`` are dropped.

    The stream under test filters out records with an even ``id``, so only the
    odd-numbered records are expected from ``_sync_records``.
    """

    class DropsRecord(SimpleTestStream):
        # Signature mirrors the base ``post_process`` (``row`` plus an optional
        # ``context``) so the override stays call-compatible with it.
        def post_process(
            self,
            row: Record,
            context: Context | None = None,  # noqa: ARG002
        ) -> Record | None:
            # Returning None signals that the record should be skipped.
            return None if row["id"] % 2 == 0 else row

    stream = DropsRecord(tap)
    records = list(stream._sync_records(None, write_messages=False))
    assert records == [
        {"id": 1, "value": "Egypt", "updatedAt": "2021-01-01T00:00:00Z"},
        {"id": 3, "value": "India", "updatedAt": "2021-01-01T00:00:02Z"},
    ]


def test_post_process_transforms_record(tap: Tap):
    """Test that changes made in ``post_process`` reach the emitted records.

    The stream under test adds an ``extra`` key to every record; each record
    yielded by ``_sync_records`` must carry it.
    """

    class TransformsRecord(SimpleTestStream):
        # Signature mirrors the base ``post_process`` (``row`` plus an optional
        # ``context``) so the override stays call-compatible with it.
        def post_process(
            self,
            row: Record,
            context: Context | None = None,  # noqa: ARG002
        ) -> Record | None:
            row["extra"] = "transformed"
            return row

    stream = TransformsRecord(tap)
    # Materialize the generator so emptiness can be checked: ``all()`` over an
    # empty iterable is vacuously true and would hide a stream emitting nothing.
    records = list(stream._sync_records(None, write_messages=False))
    assert records, "expected at least one record to be emitted"
    assert all(record["extra"] == "transformed" for record in records)