Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added parametric expressions for headers in the rest_api #2262

Open
wants to merge 1 commit into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _create_request(
path_or_url: str,
method: HTTPMethod,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthBase] = None,
hooks: Optional[Hooks] = None,
Expand All @@ -110,10 +111,12 @@ def _create_request(
else:
url = join_url(self.base_url, path_or_url)

request_headers = (self.headers or {}) | (headers or {})

return Request(
method=method,
url=url,
headers=self.headers,
headers=request_headers,
params=params,
json=json,
auth=auth or self.auth,
Expand Down Expand Up @@ -144,6 +147,7 @@ def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) ->
path_or_url=path,
method=method,
params=kwargs.pop("params", None),
headers=kwargs.pop("headers", None),
json=kwargs.pop("json", None),
auth=kwargs.pop("auth", None),
hooks=kwargs.pop("hooks", None),
Expand All @@ -161,6 +165,7 @@ def paginate(
path: str = "",
method: HTTPMethodBasic = "GET",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
Expand Down Expand Up @@ -213,7 +218,13 @@ def paginate(
hooks["response"] = [raise_for_status]

request = self._create_request(
path_or_url=path, method=method, params=params, json=json, auth=auth, hooks=hooks
path_or_url=path,
headers=headers,
method=method,
params=params,
json=json,
auth=auth,
hooks=hooks,
)

if paginator:
Expand Down
46 changes: 32 additions & 14 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generic API Source"""

from copy import deepcopy
from typing import Any, Dict, List, Optional, Generator, Callable, cast, Union
import graphlib
Expand Down Expand Up @@ -70,7 +71,11 @@ def rest_api(
) -> List[DltResource]:
"""Creates and configures a REST API source with default settings"""
return rest_api_resources(
{"client": client, "resources": resources, "resource_defaults": resource_defaults}
{
"client": client,
"resources": resources,
"resource_defaults": resource_defaults,
}
)


Expand Down Expand Up @@ -242,6 +247,7 @@ def create_resources(
endpoint_config = cast(Endpoint, endpoint_resource["endpoint"])
request_params = endpoint_config.get("params", {})
request_json = endpoint_config.get("json", None)
request_headers = endpoint_config.get("headers")
paginator = create_paginator(endpoint_config.get("paginator"))
processing_steps = endpoint_resource.pop("processing_steps", [])

Expand Down Expand Up @@ -288,6 +294,7 @@ def process(
def paginate_resource(
method: HTTPMethodBasic,
path: str,
headers: Dict[str, Any],
params: Dict[str, Any],
json: Optional[Dict[str, Any]],
paginator: Optional[BasePaginator],
Expand Down Expand Up @@ -323,6 +330,7 @@ def paginate_resource(
yield from client.paginate(
method=method,
path=path,
headers=headers,
params=params,
json=json,
paginator=paginator,
Expand All @@ -336,6 +344,7 @@ def paginate_resource(
)(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
headers=request_headers,
params=request_params,
json=request_json,
paginator=paginator,
Expand All @@ -355,6 +364,7 @@ def paginate_dependent_resource(
items: List[Dict[str, Any]],
method: HTTPMethodBasic,
path: str,
request_headers: Optional[Dict[str, Any]],
params: Dict[str, Any],
json: Optional[Dict[str, Any]],
paginator: Optional[BasePaginator],
Expand All @@ -378,23 +388,29 @@ def paginate_dependent_resource(
)

for item in items:
formatted_path, expanded_params, updated_json, parent_record = (
process_parent_data_item(
path=path,
item=item,
params=params,
request_json=json,
resolved_params=resolved_params,
include_from_parent=include_from_parent,
incremental=incremental_object,
incremental_value_convert=incremental_cursor_transform,
)
(
formatted_path,
expanded_params,
updated_json,
updated_headers,
parent_record,
) = process_parent_data_item(
path=path,
item=item,
params=params,
request_headers=request_headers,
request_json=json,
resolved_params=resolved_params,
include_from_parent=include_from_parent,
incremental=incremental_object,
incremental_value_convert=incremental_cursor_transform,
)

for child_page in client.paginate(
method=method,
path=formatted_path,
params=expanded_params,
headers=updated_headers,
json=updated_json,
paginator=paginator,
data_selector=data_selector,
Expand All @@ -413,6 +429,7 @@ def paginate_dependent_resource(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
params=base_params,
request_headers=request_headers,
json=request_json,
paginator=paginator,
data_selector=endpoint_config.get("data_selector"),
Expand Down Expand Up @@ -456,7 +473,8 @@ def _mask_secrets(auth_config: AuthConfig) -> AuthConfig:
has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS)
if (
isinstance(
auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials)
auth_config,
(APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials),
)
or has_sensitive_key
):
Expand Down Expand Up @@ -503,7 +521,7 @@ def identity_func(x: Any) -> Any:


def _validate_param_type(
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]],
) -> None:
for _, value in request_params.items():
if isinstance(value, dict) and value.get("type") not in PARAM_TYPES:
Expand Down
30 changes: 24 additions & 6 deletions dlt/sources/rest_api/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def setup_incremental_object(


def parse_convert_or_deprecated_transform(
config: Union[IncrementalConfig, Dict[str, Any]]
config: Union[IncrementalConfig, Dict[str, Any]],
) -> Optional[Callable[..., Any]]:
convert = config.get("convert", None)
deprecated_transform = config.get("transform", None)
Expand Down Expand Up @@ -317,15 +317,20 @@ def build_resource_dependency_graph(
endpoint_resource["endpoint"]["path"], available_contexts
)

# Find all expressions in params and json, but error if any of them is not in available_contexts
# Find all expressions in params, json, and headers, but raise an error if any of them is not in available_contexts
params_expressions = _find_expressions(endpoint_resource["endpoint"].get("params", {}))
_raise_if_any_not_in(params_expressions, available_contexts, message="params")

json_expressions = _find_expressions(endpoint_resource["endpoint"].get("json", {}))
_raise_if_any_not_in(json_expressions, available_contexts, message="json")

headers_expressions = _find_expressions(endpoint_resource["endpoint"].get("headers", {}))
_raise_if_any_not_in(headers_expressions, available_contexts, message="headers")

resolved_params += _expressions_to_resolved_params(
_filter_resource_expressions(path_expressions | params_expressions | json_expressions)
_filter_resource_expressions(
path_expressions | params_expressions | json_expressions | headers_expressions
)
)

# set of resources in resolved params
Expand Down Expand Up @@ -723,11 +728,12 @@ def process_parent_data_item(
item: Dict[str, Any],
resolved_params: List[ResolvedParam],
params: Optional[Dict[str, Any]] = None,
request_headers: Optional[Dict[str, Any]] = None,
request_json: Optional[Dict[str, Any]] = None,
include_from_parent: Optional[List[str]] = None,
incremental: Optional[Incremental[Any]] = None,
incremental_value_convert: Optional[Callable[..., Any]] = None,
) -> Tuple[str, Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
) -> Tuple[str, Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
params_values = collect_resolved_values(
item, resolved_params, incremental, incremental_value_convert
)
Expand All @@ -737,10 +743,20 @@ def process_parent_data_item(
None if request_json is None else expand_placeholders(request_json, params_values)
)

expanded_headers = (
None if request_headers is None else expand_placeholders(request_headers, params_values)
)

parent_resource_name = resolved_params[0].resolve_config["resource"]
parent_record = build_parent_record(item, parent_resource_name, include_from_parent)

return expanded_path, expanded_params, expanded_json, parent_record
return (
expanded_path,
expanded_params,
expanded_json,
expanded_headers,
parent_record,
)


def convert_incremental_values(
Expand Down Expand Up @@ -819,7 +835,9 @@ def expand_placeholders(obj: Any, placeholders: Dict[str, Any]) -> Any:


def build_parent_record(
item: Dict[str, Any], parent_resource_name: str, include_from_parent: Optional[List[str]]
item: Dict[str, Any],
parent_resource_name: str,
include_from_parent: Optional[List[str]],
) -> Dict[str, Any]:
"""
Builds a dictionary of the `include_from_parent` fields from the parent,
Expand Down
1 change: 1 addition & 0 deletions dlt/sources/rest_api/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ class Endpoint(TypedDict, total=False):
response_actions: Optional[List[ResponseAction]]
incremental: Optional[IncrementalConfig]
auth: Optional[AuthConfig]
headers: Optional[Dict[str, Any]]


class ProcessingSteps(TypedDict):
Expand Down
Loading
Loading