diff --git a/docs/classes/singer_sdk.connectors.BaseConnector.rst b/docs/classes/singer_sdk.connectors.BaseConnector.rst new file mode 100644 index 000000000..3ba703887 --- /dev/null +++ b/docs/classes/singer_sdk.connectors.BaseConnector.rst @@ -0,0 +1,8 @@ +singer_sdk.connectors.BaseConnector +=================================== + +.. currentmodule:: singer_sdk.connectors + +.. autoclass:: BaseConnector + :members: + :special-members: __init__, __call__ \ No newline at end of file diff --git a/docs/guides/custom-connector.md b/docs/guides/custom-connector.md new file mode 100644 index 000000000..cf0ec7e20 --- /dev/null +++ b/docs/guides/custom-connector.md @@ -0,0 +1,32 @@ +# Using a custom connector class + +The Singer SDK has a few built-in connector classes that are designed to work with a variety of sources: + +* [`SQLConnector`](../../classes/singer_sdk.SQLConnector) for SQL databases + +If you need to connect to a source that is not supported by one of these built-in connectors, you can create your own connector class. This guide will walk you through the process of creating a custom connector class. + +## Subclass `BaseConnector` + +The first step is to create a subclass of [`BaseConnector`](../../classes/singer_sdk.connectors.BaseConnector). This class is responsible for creating streams and handling the connection to the source. + +```python +from singer_sdk.connectors import BaseConnector + + +class MyConnector(BaseConnector): + pass +``` + +## Implement `get_connection` + +The [`get_connection`](http://127.0.0.1:5500/build/classes/singer_sdk.connectors.BaseConnector.html#singer_sdk.connectors.BaseConnector.get_connection) method is responsible for creating a connection to the source. It should return an object that implements the [context manager protocol](https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers), e.g. it has `__enter__` and `__exit__` methods. + +```python +from singer_sdk.connectors import BaseConnector + + +class MyConnector(BaseConnector): + def get_connection(self): + return MyConnection() +``` diff --git a/docs/guides/index.md b/docs/guides/index.md index 89e3a0c19..ee64b7208 100644 --- a/docs/guides/index.md +++ b/docs/guides/index.md @@ -7,6 +7,7 @@ The following pages contain useful information for developers building on top of porting pagination-classes +custom-connector custom-clis config-schema performance diff --git a/docs/reference.rst b/docs/reference.rst index 6522c8a36..9c0f72180 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -144,6 +144,15 @@ Batch batch.BaseBatcher batch.JSONLinesBatcher +Abstract Connector Classes +-------------------------- + +.. autosummary:: + :toctree: classes + :template: class.rst + + connectors.BaseConnector + Other ----- diff --git a/pyproject.toml b/pyproject.toml index 3342e3020..8e249b54c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ testing = [ "moto>=5.0.14", "pytest>=7.2.1", "pytest-benchmark>=4.0.0", + "pytest-httpserver>=1.1.3", "pytest-snapshot>=0.9.0", "pytest-subtests>=0.13.1", "pytz>=2022.2.1", diff --git a/singer_sdk/authenticators.py b/singer_sdk/authenticators.py index 916669ef7..715dda428 100644 --- a/singer_sdk/authenticators.py +++ b/singer_sdk/authenticators.py @@ -11,6 +11,7 @@ from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit import requests +from requests.auth import AuthBase from singer_sdk.helpers._util import utc_now @@ -593,3 +594,52 @@ def oauth_request_payload(self) -> dict: "RS256", ), } + + +class NoopAuth(AuthBase): + """No-op authenticator.""" + + def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + """Do nothing. + + Args: + r: The prepared request. + + Returns: + The unmodified prepared request. + """ + return r + + +class HeaderAuth(AuthBase): + """Header-based authenticator.""" + + def __init__( + self, + keyword: str, + value: str, + header: str = "Authorization", + ) -> None: + """Initialize the authenticator. + + Args: + keyword: The keyword to use in the header, e.g. "Bearer". + value: The value to use in the header, e.g. "my-token". + header: The header to add the keyword and value to, defaults to + ``"Authorization"``. + """ + self.keyword = keyword + self.value = value + self.header = header + + def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + """Add the header to the request. + + Args: + r: The prepared request. + + Returns: + The prepared request with the header added. + """ + r.headers[self.header] = f"{self.keyword} {self.value}" + return r diff --git a/singer_sdk/connectors/__init__.py b/singer_sdk/connectors/__init__.py index 32799417a..1c3916672 100644 --- a/singer_sdk/connectors/__init__.py +++ b/singer_sdk/connectors/__init__.py @@ -2,6 +2,8 @@ from __future__ import annotations +from ._http import HTTPConnector +from .base import BaseConnector from .sql import SQLConnector -__all__ = ["SQLConnector"] +__all__ = ["BaseConnector", "HTTPConnector", "SQLConnector"] diff --git a/singer_sdk/connectors/_http.py b/singer_sdk/connectors/_http.py new file mode 100644 index 000000000..b217b1e8b --- /dev/null +++ b/singer_sdk/connectors/_http.py @@ -0,0 +1,140 @@ +"""HTTP-based tap class for Singer SDK.""" + +from __future__ import annotations + +import typing as t + +import requests + +from singer_sdk.authenticators import NoopAuth +from singer_sdk.connectors.base import BaseConnector + +if t.TYPE_CHECKING: + import sys + from collections.abc import Mapping + + from requests.adapters import BaseAdapter + + if sys.version_info >= (3, 10): + from typing import TypeAlias # noqa: ICN003 + else: + from typing_extensions import TypeAlias + +_Auth: TypeAlias = t.Callable[[requests.PreparedRequest], requests.PreparedRequest] + + +class HTTPConnector(BaseConnector[requests.Session]): + """Base class for all HTTP-based connectors.""" + + def __init__(self, config: Mapping[str, t.Any] | None = None) -> None: + """Initialize the HTTP connector. + + Args: + config: Connector configuration parameters. + """ + super().__init__(config) + self.__session = self.get_session() + self.refresh_auth() + + def get_connection(self, *, authenticate: bool = True) -> requests.Session: + """Return a new HTTP session object. + + Adds adapters and optionally authenticates the session. + + Args: + authenticate: Whether to authenticate the request. + + Returns: + A new HTTP session object. + """ + for prefix, adapter in self.adapters.items(): + self.__session.mount(prefix, adapter) + + self.__session.auth = self.auth if authenticate else None + + return self.__session + + def get_session(self) -> requests.Session: # noqa: PLR6301 + """Return a new HTTP session object. + + Returns: + A new HTTP session object. + """ + return requests.Session() + + def get_authenticator(self) -> _Auth: # noqa: PLR6301 + """Authenticate the HTTP session. + + Returns: + An auth callable. + """ + return NoopAuth() + + def refresh_auth(self) -> None: + """Refresh the HTTP session authentication.""" + self.auth = self.get_authenticator() + + @property + def auth(self) -> _Auth: + """Return the HTTP session authenticator. + + Returns: + An auth callable. + """ + return self.__auth + + @auth.setter + def auth(self, auth: _Auth) -> None: + """Set the HTTP session authenticator. + + Args: + auth: An auth callable. + """ + self.__auth = auth + + @property + def session(self) -> requests.Session: + """Return the HTTP session object. + + Returns: + The HTTP session object. + """ + return self.__session + + @property + def adapters(self) -> dict[str, BaseAdapter]: + """Return a mapping of URL prefixes to adapter objects. + + Returns: + A mapping of URL prefixes to adapter objects. + """ + return {} + + @property + def default_request_kwargs(self) -> dict[str, t.Any]: + """Return default kwargs for HTTP requests. + + Returns: + A mapping of default kwargs for HTTP requests. + """ + return {} + + def request( + self, + *args: t.Any, + authenticate: bool = True, + **kwargs: t.Any, + ) -> requests.Response: + """Make an HTTP request. + + Args: + *args: Positional arguments to pass to the request method. + authenticate: Whether to authenticate the request. + **kwargs: Keyword arguments to pass to the request method. + + Returns: + The HTTP response object. + """ + with self.connect(authenticate=authenticate) as session: + kwargs = {**self.default_request_kwargs, **kwargs} + return session.request(*args, **kwargs) diff --git a/singer_sdk/connectors/base.py b/singer_sdk/connectors/base.py new file mode 100644 index 000000000..1e73cb609 --- /dev/null +++ b/singer_sdk/connectors/base.py @@ -0,0 +1,66 @@ +"""Base class for all connectors.""" + +from __future__ import annotations + +import abc +import typing as t +from contextlib import contextmanager + +if t.TYPE_CHECKING: + from collections.abc import Mapping + +_T = t.TypeVar("_T") + + +# class BaseConnector(abc.ABC, t.Generic[_T_co]): +class BaseConnector(abc.ABC, t.Generic[_T]): + """Base class for all connectors.""" + + def __init__(self, config: Mapping[str, t.Any] | None = None) -> None: + """Initialize the connector. + + Args: + config: Plugin configuration parameters. + """ + self._config = config or {} + + @property + def config(self) -> Mapping[str, t.Any]: + """Return the connector configuration. + + Returns: + A mapping of configuration parameters. + """ + return self._config + + @config.setter + def config(self, config: Mapping[str, t.Any]) -> None: + """Set the connector configuration. + + Args: + config: Plugin configuration parameters. + """ + self._config = config + + @contextmanager + def connect(self, *args: t.Any, **kwargs: t.Any) -> t.Generator[_T, None, None]: + """Connect to the destination. + + Args: + args: Positional arguments to pass to the connection method. + kwargs: Keyword arguments to pass to the connection method. + + Yields: + A connection object. + """ + yield self.get_connection(*args, **kwargs) + + @abc.abstractmethod + def get_connection(self, *args: t.Any, **kwargs: t.Any) -> _T: + """Connect to the destination. + + Args: + args: Positional arguments to pass to the connection method. + kwargs: Keyword arguments to pass to the connection method. + """ + ... diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index ba0d32465..3466f0a76 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -16,6 +16,7 @@ from sqlalchemy.engine import reflection from singer_sdk import typing as th +from singer_sdk.connectors.base import BaseConnector from singer_sdk.exceptions import ConfigValidationError from singer_sdk.helpers._compat import SingerSDKDeprecationWarning from singer_sdk.helpers._util import dump_json, load_json @@ -33,6 +34,8 @@ from typing import TypeAlias # noqa: ICN003 if t.TYPE_CHECKING: + from collections.abc import Mapping + from sqlalchemy.engine import Engine from sqlalchemy.engine.reflection import Inspector @@ -140,7 +143,10 @@ def __init__(self, *, use_singer_decimal: bool = False) -> None: self.use_singer_decimal = use_singer_decimal @classmethod - def from_config(cls: type[SQLToJSONSchema], config: dict) -> SQLToJSONSchema: + def from_config( + cls: type[SQLToJSONSchema], + config: Mapping[str, t.Any], + ) -> SQLToJSONSchema: """Create a new instance from a configuration dictionary. Override this to instantiate this converter with values from the tap's @@ -303,7 +309,7 @@ def __init__(self, *, max_varchar_length: int | None = None) -> None: @classmethod def from_config( cls: type[JSONSchemaToSQL], - config: dict, # noqa: ARG003 + config: Mapping[str, t.Any], # noqa: ARG003 *, max_varchar_length: int | None, ) -> JSONSchemaToSQL: @@ -551,7 +557,7 @@ def to_sql_type(self, schema: dict) -> sa.types.TypeEngine: return self.fallback_type() -class SQLConnector: # noqa: PLR0904 +class SQLConnector(BaseConnector[sa.engine.Connection]): # noqa: PLR0904 """Base class for SQLAlchemy-based connectors. The connector class serves as a wrapper around the SQL connection. @@ -587,7 +593,7 @@ class SQLConnector: # noqa: PLR0904 def __init__( self, - config: dict | None = None, + config: Mapping[str, t.Any] | None = None, sqlalchemy_url: str | None = None, ) -> None: """Initialize the SQL connector. @@ -596,18 +602,9 @@ def __init__( config: The parent tap or target object's config. sqlalchemy_url: Optional URL for the connection. """ - self._config: dict[str, t.Any] = config or {} + super().__init__(config=config) self._sqlalchemy_url: str | None = sqlalchemy_url or None - @property - def config(self) -> dict: - """If set, provides access to the tap or target config. - - Returns: - The settings as a dict. - """ - return self._config - @property def logger(self) -> logging.Logger: """Get logger. @@ -641,9 +638,35 @@ def jsonschema_to_sql(self) -> JSONSchemaToSQL: ) @contextmanager - def _connect(self) -> t.Iterator[sa.engine.Connection]: - with self._engine.connect().execution_options(stream_results=True) as conn: - yield conn + def _connect(self): # noqa: ANN202 + """Connect to the source. + + Yields: + A connection object. + """ + warnings.warn( + "`SQLConnector._connect` is deprecated. " + "Use `SQLConnector.connect` instead.", + DeprecationWarning, + stacklevel=2, + ) + with self.connect() as connection: + yield connection + + def get_connection( + self, + *, + stream_results: bool = True, + ) -> sa.engine.Connection: + """Return a new SQLAlchemy connection using the provided config. + + Args: + stream_results: Whether to stream results from the database. + + Returns: + A newly created SQLAlchemy connection object. + """ + return self._engine.connect().execution_options(stream_results=stream_results) @deprecated( "`SQLConnector.create_sqlalchemy_connection` is deprecated. " @@ -723,7 +746,7 @@ def sqlalchemy_url(self) -> str: return self._sqlalchemy_url - def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str: # noqa: PLR6301 + def get_sqlalchemy_url(self, config: Mapping[str, t.Any]) -> str: # noqa: PLR6301 """Return the SQLAlchemy URL string. Developers can generally override just one of the following: @@ -1278,7 +1301,7 @@ def create_schema(self, schema_name: str) -> None: Args: schema_name: The target schema to create. """ - with self._connect() as conn, conn.begin(): + with self.connect() as conn, conn.begin(): conn.execute(sa.schema.CreateSchema(schema_name)) def create_empty_table( @@ -1355,7 +1378,7 @@ def _create_empty_column( column_name=column_name, column_type=sql_type, ) - with self._connect() as conn, conn.begin(): + with self.connect() as conn, conn.begin(): conn.execute(column_add_ddl) def prepare_schema(self, schema_name: str) -> None: @@ -1482,7 +1505,7 @@ def rename_column( column_name=old_name, new_column_name=new_name, ) - with self._connect() as conn, conn.begin(): + with self.connect() as conn, conn.begin(): conn.execute(column_rename_ddl) def merge_sql_types( @@ -1789,7 +1812,7 @@ def _adapt_column_type( column_name=column_name, column_type=compatible_sql_type, ) - with self._connect() as conn, conn.begin(): + with self.connect() as conn, conn.begin(): conn.execute(alter_column_ddl) def serialize_json(self, obj: object) -> str: # noqa: PLR6301 @@ -1841,7 +1864,7 @@ def delete_old_versions( version_column_name: The name of the version column. current_version: The current ACTIVATE version of the table. """ - with self._connect() as conn, conn.begin(): + with self.connect() as conn, conn.begin(): conn.execute( sa.text( f"DELETE FROM {full_table_name} " # noqa: S608 diff --git a/singer_sdk/sinks/sql.py b/singer_sdk/sinks/sql.py index 9a82e59bd..78ec3d2ff 100644 --- a/singer_sdk/sinks/sql.py +++ b/singer_sdk/sinks/sql.py @@ -347,8 +347,7 @@ def bulk_insert_records( ] self.logger.info("Inserting with SQL: %s", insert_sql) - - with self.connector._connect() as conn, conn.begin(): # noqa: SLF001 + with self.connector.connect() as conn, conn.begin(): result = conn.execute(insert_sql, new_records) return result.rowcount @@ -427,7 +426,7 @@ def activate_version(self, new_version: int) -> None: bindparam("deletedate", value=deleted_at, type_=sa.types.DateTime), bindparam("version", value=new_version, type_=sa.types.Integer), ) - with self.connector._connect() as conn, conn.begin(): # noqa: SLF001 + with self.connector.connect() as conn, conn.begin(): conn.execute(query) diff --git a/singer_sdk/streams/rest.py b/singer_sdk/streams/rest.py index 524c8342a..13b39c312 100644 --- a/singer_sdk/streams/rest.py +++ b/singer_sdk/streams/rest.py @@ -17,6 +17,7 @@ from singer_sdk import metrics from singer_sdk.authenticators import SimpleAuthenticator +from singer_sdk.connectors import HTTPConnector from singer_sdk.exceptions import FatalAPIError, RetriableAPIError from singer_sdk.helpers._compat import SingerSDKDeprecationWarning from singer_sdk.helpers.jsonpath import extract_jsonpath @@ -48,7 +49,7 @@ class _HTTPStream(Stream, t.Generic[_TToken], metaclass=abc.ABCMeta): # noqa: P """Abstract base class for HTTP streams.""" _page_size: int = DEFAULT_PAGE_SIZE - _requests_session: requests.Session | None + _requests_session: requests.Session #: Response code reference for rate limit retries extra_retry_statuses: t.Sequence[int] = [HTTPStatus.TOO_MANY_REQUESTS] @@ -87,6 +88,7 @@ def __init__( path: str | None = None, *, http_method: str | None = None, + connector: HTTPConnector | None = None, ) -> None: """Initialize the HTTP stream. @@ -96,13 +98,25 @@ def __init__( name: Name of this stream. path: URL path for this entity stream. http_method: HTTP method to use for requests. + connector: Connector to use for HTTP requests. """ if path: self.path = path + + self.connector = connector or HTTPConnector() + + self._requests_session = self.connector.session + self._compiled_jsonpath = None + self._next_page_token_compiled_jsonpath = None self._http_method = http_method - self._requests_session = requests.Session() super().__init__(name=name, schema=schema, tap=tap) + # Override the connector's config with the stream's config + self.connector.config = self.config + + # Override the connector's auth with the stream's auth + self.connector.auth = self.authenticator + @staticmethod def _url_encode(val: str | datetime | bool | int | list[str]) -> str: # noqa: FBT001 """Encode the val argument as url-compatible string. @@ -183,8 +197,12 @@ def requests_session(self) -> requests.Session: Returns: The :class:`requests.Session` object for HTTP requests. """ - if not self._requests_session: - self._requests_session = requests.Session() + warn( + "The `requests_session` property is deprecated and will be removed in a " + "future release. Use the `connector` property instead.", + DeprecationWarning, + stacklevel=2, + ) return self._requests_session @cached_property @@ -319,11 +337,13 @@ def _request( Returns: TODO """ - response = self.requests_session.send( - prepared_request, - timeout=self.timeout, - allow_redirects=self.allow_redirects, - ) + with self.connector.connect() as session: + response = session.send( + prepared_request, + timeout=self.timeout, + allow_redirects=self.allow_redirects, + ) + self._write_request_duration_log( endpoint=self.path, response=response, @@ -387,8 +407,8 @@ def build_prepared_request( A :class:`requests.PreparedRequest` object. """ request = requests.Request(*args, **kwargs) - self.requests_session.auth = self.authenticator - return self.requests_session.prepare_request(request) + with self.connector.connect(authenticate=True) as session: + return session.prepare_request(request) def prepare_request( self, @@ -771,6 +791,7 @@ def __init__( path: str | None = None, *, http_method: str | None = None, + connector: HTTPConnector | None = None, ) -> None: """Initialize the REST stream. @@ -779,9 +800,17 @@ def __init__( schema: JSON schema for records in this stream. name: Name of this stream. path: URL path for this entity stream. - http_method: HTTP method to use for requests - """ - super().__init__(tap, name, schema, path, http_method=http_method) + http_method: HTTP method to use for requests. + connector: Connector to use for HTTP requests. + """ + super().__init__( + tap, + name, + schema, + path, + http_method=http_method, + connector=connector, + ) self._compiled_jsonpath = None self._next_page_token_compiled_jsonpath = None diff --git a/singer_sdk/streams/sql.py b/singer_sdk/streams/sql.py index 257843319..f7c904409 100644 --- a/singer_sdk/streams/sql.py +++ b/singer_sdk/streams/sql.py @@ -208,7 +208,7 @@ def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]: # processed. query = query.limit(self.ABORT_AT_RECORD_COUNT + 1) - with self.connector._connect() as conn: # noqa: SLF001 + with self.connector.connect() as conn: for record in conn.execute(query).mappings(): transformed_record = self.post_process(dict(record)) if transformed_record is None: diff --git a/tests/core/connectors/__init__.py b/tests/core/connectors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/core/connectors/test_http_connector.py b/tests/core/connectors/test_http_connector.py new file mode 100644 index 000000000..2c62ec94a --- /dev/null +++ b/tests/core/connectors/test_http_connector.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import json +import typing as t + +import requests +import werkzeug +from requests.adapters import BaseAdapter + +from singer_sdk.authenticators import HeaderAuth +from singer_sdk.connectors import HTTPConnector + +if t.TYPE_CHECKING: + from pytest_httpserver import HTTPServer + + +class MockAdapter(BaseAdapter): + def send( + self, + request: requests.PreparedRequest, + stream: bool = False, # noqa: FBT002 + timeout: float | tuple[float, float] | tuple[float, None] | None = None, + verify: bool | str = True, # noqa: FBT002 + cert: bytes | str | tuple[bytes | str, bytes | str] | None = None, + proxies: t.Mapping[str, str] | None = None, + ) -> requests.Response: + """Send a request.""" + response = requests.Response() + data = { + "url": request.url, + "headers": dict(request.headers), + "method": request.method, + "body": request.body, + "stream": stream, + "timeout": timeout, + "verify": verify, + "cert": cert, + "proxies": proxies, + } + response.status_code = 200 + response._content = json.dumps(data).encode("utf-8") + return response + + def close(self) -> None: + pass + + +class HeaderAuthConnector(HTTPConnector): + def get_authenticator(self) -> HeaderAuth: + return HeaderAuth("Bearer", self.config["token"]) + + +def test_base_connector(httpserver: HTTPServer): + connector = HTTPConnector({}) + + httpserver.expect_request("").respond_with_json({"foo": "bar"}) + url = httpserver.url_for("/") + + response = connector.request("GET", url) + data = response.json() + assert data["foo"] == "bar" + + +def test_auth(httpserver: HTTPServer): + connector = HeaderAuthConnector({"token": "s3cr3t"}) + + def _handler(request: werkzeug.Request) -> werkzeug.Response: + return werkzeug.Response( + json.dumps( + { + "headers": dict(request.headers), + "url": request.url, + }, + ), + status=200, + mimetype="application/json", + ) + + httpserver.expect_request("").respond_with_handler(_handler) + url = httpserver.url_for("/") + + response = connector.request("GET", url) + data = response.json() + assert data["headers"]["Authorization"] == "Bearer s3cr3t" + + response = connector.request("GET", url, authenticate=False) + data = response.json() + assert "Authorization" not in data["headers"] + + +def test_custom_adapters(): + class MyConnector(HTTPConnector): + @property + def adapters(self) -> dict[str, BaseAdapter]: + return { + "https://test": MockAdapter(), + } + + connector = MyConnector({}) + response = connector.request("GET", "https://test") + data = response.json() + + assert data["url"] == "https://test/" + assert data["headers"] + assert data["method"] == "GET" diff --git a/tests/core/test_connector_sql.py b/tests/core/connectors/test_sql_connector.py similarity index 97% rename from tests/core/test_connector_sql.py rename to tests/core/connectors/test_sql_connector.py index 0c7fa51a7..d8c8ebfc3 100644 --- a/tests/core/test_connector_sql.py +++ b/tests/core/connectors/test_sql_connector.py @@ -174,37 +174,47 @@ def test_deprecated_functions_warn(self, connector: SQLConnector): connector.create_sqlalchemy_connection() with pytest.deprecated_call(): _ = connector.connection + with pytest.deprecated_call(), connector._connect() as _: + pass - def test_connect_calls_engine(self, connector): + def test_connect_calls_engine(self, connector: SQLConnector): with ( - mock.patch.object(SQLConnector, "_engine") as mock_engine, - connector._connect() as _, + mock.patch.object( + SQLConnector, + "_engine", + ) as mock_engine, + connector.connect() as _, ): mock_engine.connect.assert_called_once() - def test_connect_calls_connect(self, connector): + def test_connect_calls_connect(self, connector: SQLConnector): attached_engine = connector._engine with ( - mock.patch.object(attached_engine, "connect") as mock_conn, - connector._connect() as _, + mock.patch.object( + attached_engine, + "connect", + ) as mock_conn, + connector.connect() as _, ): mock_conn.assert_called_once() - def test_connect_raises_on_operational_failure(self, connector): + def test_connect_raises_on_operational_failure(self, connector: SQLConnector): with ( - pytest.raises(sa.exc.OperationalError) as _, - connector._connect() as conn, + pytest.raises( + sa.exc.OperationalError, + ) as _, + connector.connect() as conn, ): conn.execute(sa.text("SELECT * FROM fake_table")) - def test_rename_column_uses_connect_correctly(self, connector): + def test_rename_column_uses_connect_correctly(self, connector: SQLConnector): attached_engine = connector._engine # Ends up using the attached engine with mock.patch.object(attached_engine, "connect") as mock_conn: connector.rename_column("fake_table", "old_name", "new_name") mock_conn.assert_called_once() # Uses the _connect method - with mock.patch.object(connector, "_connect") as mock_connect_method: + with mock.patch.object(connector, "connect") as mock_connect_method: connector.rename_column("fake_table", "old_name", "new_name") mock_connect_method.assert_called_once() diff --git a/tests/core/test_streams.py b/tests/core/test_streams.py index 7d1b869f2..55480e65e 100644 --- a/tests/core/test_streams.py +++ b/tests/core/test_streams.py @@ -737,3 +737,9 @@ def discover_streams(self): assert all( tap.streams[stream].selected is selection[stream] for stream in selection ) + + +def test_deprecations(tap: SimpleTestTap): + stream = RestTestStream(tap=tap) + with pytest.deprecated_call(): + _ = stream.requests_session diff --git a/tests/samples/conftest.py b/tests/samples/conftest.py index 1202dc688..d8036cc01 100644 --- a/tests/samples/conftest.py +++ b/tests/samples/conftest.py @@ -21,7 +21,7 @@ def csv_config(outdir: str) -> dict: @pytest.fixture def sqlite_sample_db(sqlite_connector: SQLiteConnector): """Return a path to a newly constructed sample DB.""" - with sqlite_connector._connect() as conn, conn.begin(): + with sqlite_connector.connect() as conn, conn.begin(): for t in range(3): conn.execute(sa.text(f"DROP TABLE IF EXISTS t{t}")) conn.execute( diff --git a/uv.lock b/uv.lock index daf137e7b..c4811476f 100644 --- a/uv.lock +++ b/uv.lock @@ -1945,6 +1945,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/9b/952c70bd1fae9baa58077272e7f191f377c86d812263c21b361195e125e6/pytest_codspeed-3.2.0-py3-none-any.whl", hash = "sha256:54b5c2e986d6a28e7b0af11d610ea57bd5531cec8326abe486f1b55b09d91c39", size = 15007 }, ] +[[package]] +name = "pytest-httpserver" +version = "1.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "werkzeug" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f1/d8/def15ba33bd696dd72dd4562a5287c0cba4d18a591eeb82e0b08ab385afc/pytest_httpserver-1.1.3.tar.gz", hash = "sha256:af819d6b533f84b4680b9416a5b3f67f1df3701f1da54924afd4d6e4ba5917ec", size = 68870 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/d2/dfc2f25f3905921c2743c300a48d9494d29032f1389fc142e718d6978fb2/pytest_httpserver-1.1.3-py3-none-any.whl", hash = "sha256:5f84757810233e19e2bb5287f3826a71c97a3740abe3a363af9155c0f82fdbb9", size = 21000 }, +] + [[package]] name = "pytest-snapshot" version = "0.9.0" @@ -2479,6 +2491,7 @@ benchmark = [ { name = "pytest" }, { name = "pytest-benchmark" }, { name = "pytest-codspeed" }, + { name = "pytest-httpserver" }, { name = "pytest-snapshot" }, { name = "pytest-subtests" }, { name = "pytz" }, @@ -2501,6 +2514,7 @@ dev = [ { name = "myst-parser", version = "4.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "pytest" }, { name = "pytest-benchmark" }, + { name = "pytest-httpserver" }, { name = "pytest-snapshot" }, { name = "pytest-subtests" }, { name = "pytz" }, @@ -2544,6 +2558,7 @@ testing = [ { name = "moto" }, { name = "pytest" }, { name = "pytest-benchmark" }, + { name = "pytest-httpserver" }, { name = "pytest-snapshot" }, { name = "pytest-subtests" }, { name = "pytz" }, @@ -2609,6 +2624,7 @@ benchmark = [ { name = "pytest", specifier = ">=7.2.1" }, { name = "pytest-benchmark", specifier = ">=4.0.0" }, { name = "pytest-codspeed", specifier = ">=2.2.0" }, + { name = "pytest-httpserver", specifier = ">=1.1.3" }, { name = "pytest-snapshot", specifier = ">=0.9.0" }, { name = "pytest-subtests", specifier = ">=0.13.1" }, { name = "pytz", specifier = ">=2022.2.1" }, @@ -2630,6 +2646,7 @@ dev = [ { name = "myst-parser", specifier = ">=3" }, { name = "pytest", specifier = ">=7.2.1" }, { name = "pytest-benchmark", specifier = ">=4.0.0" }, + { name = "pytest-httpserver", specifier = ">=1.1.3" }, { name = "pytest-snapshot", specifier = ">=0.9.0" }, { name = "pytest-subtests", specifier = ">=0.13.1" }, { name = "pytz", specifier = ">=2022.2.1" }, @@ -2667,6 +2684,7 @@ testing = [ { name = "moto", specifier = ">=5.0.14" }, { name = "pytest", specifier = ">=7.2.1" }, { name = "pytest-benchmark", specifier = ">=4.0.0" }, + { name = "pytest-httpserver", specifier = ">=1.1.3" }, { name = "pytest-snapshot", specifier = ">=0.9.0" }, { name = "pytest-subtests", specifier = ">=0.13.1" }, { name = "pytz", specifier = ">=2022.2.1" },