Skip to content

Commit

Permalink
replace MacroResolver with MacroClient
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Dec 7, 2023
1 parent bb12375 commit 4f498f9
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 84 deletions.
63 changes: 19 additions & 44 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ConstraintType,
ModelLevelConstraint,
)
from dbt.adapters.contracts.macros import MacroResolverProtocol
from dbt.adapters.contracts.macros import MacroClient

import agate
import pytz
Expand Down Expand Up @@ -62,7 +62,7 @@
table_from_rows,
Integer,
)
from dbt.common.clients.jinja import CallableMacroGenerator, MacroProtocol
from dbt.common.clients.jinja import CallableMacroGenerator
from dbt.common.events.functions import fire_event, warn_or_error
from dbt.adapters.events.types import (
CacheMiss,
Expand All @@ -80,7 +80,6 @@
Connection,
AdapterResponse,
BaseConnectionManager,
AdapterRequiredConfig,
)
from dbt.adapters.base.meta import AdapterMeta, available
from dbt.adapters.base.relation import (
Expand Down Expand Up @@ -262,35 +261,20 @@ def __init__(self, config, mp_context: SpawnContext) -> None:
self.config = config
self.cache = RelationsCache(log_cache_events=config.log_cache_events)
self.connections = self.ConnectionManager(config, mp_context)
self._macro_resolver: Optional[MacroResolverProtocol] = None
self._macro_context_generator: Optional[
Callable[
[MacroProtocol, AdapterRequiredConfig, MacroResolverProtocol, Optional[str]],
Dict[str, Any],
]
] = None
self._macro_client: Optional[MacroClient] = None

###
# Methods to set / access a macro resolver
# Methods to set / access a macro client
###
def set_macro_resolver(self, macro_resolver: MacroResolverProtocol) -> None:
self._macro_resolver = macro_resolver
def set_macro_client(self, macro_client: MacroClient) -> None:
self._macro_client = macro_client

def get_macro_resolver(self) -> Optional[MacroResolverProtocol]:
return self._macro_resolver
def get_macro_client(self) -> Optional[MacroClient]:
return self._macro_client

def clear_macro_resolver(self) -> None:
if self._macro_resolver is not None:
self._macro_resolver = None

def set_macro_context_generator(
self,
macro_context_generator: Callable[
[MacroProtocol, AdapterRequiredConfig, MacroResolverProtocol, Optional[str]],
Dict[str, Any],
],
) -> None:
self._macro_context_generator = macro_context_generator
def clear_macro_client(self) -> None:
if self._macro_client is not None:
self._macro_client = None

Check warning on line 277 in core/dbt/adapters/base/impl.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/base/impl.py#L276-L277

Added lines #L276 - L277 were not covered by tests

###
# Methods that pass through to the connection manager
Expand Down Expand Up @@ -1052,7 +1036,6 @@ def convert_agate_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[
def execute_macro(
self,
macro_name: str,
macro_resolver: Optional[MacroResolverProtocol] = None,
project: Optional[str] = None,
context_override: Optional[Dict[str, Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -1075,14 +1058,12 @@ def execute_macro(
if context_override is None:
context_override = {}

resolver = macro_resolver or self._macro_resolver
if resolver is None:
raise DbtInternalError("Macro resolver was None when calling execute_macro!")

if self._macro_context_generator is None:
raise DbtInternalError("Macro context generator was None when calling execute_macro!")
if self._macro_client is None:
raise DbtInternalError("Macro client was None when calling execute_macro!")

Check warning on line 1062 in core/dbt/adapters/base/impl.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/base/impl.py#L1062

Added line #L1062 was not covered by tests

macro = resolver.find_macro_by_name(macro_name, self.config.project_name, project)
macro = self._macro_client.find_macro_by_name(
macro_name, self.config.project_name, project
)
if macro is None:
if project is None:
package_name = "any package"
Expand All @@ -1095,7 +1076,7 @@ def execute_macro(
)
)

macro_context = self._macro_context_generator(macro, self.config, resolver, project)
macro_context = self._macro_client.generate_context_for_macro(self.config, macro, project)
macro_context.update(context_override)

macro_function = CallableMacroGenerator(macro, macro_context)
Expand Down Expand Up @@ -1243,7 +1224,6 @@ def calculate_freshness(
source: BaseRelation,
loaded_at_field: str,
filter: Optional[str],
macro_resolver: Optional[MacroResolverProtocol] = None,
) -> Tuple[Optional[AdapterResponse], FreshnessResponse]:
"""Calculate the freshness of sources in dbt, and return it"""
kwargs: Dict[str, Any] = {
Expand All @@ -1259,9 +1239,7 @@ def calculate_freshness(
AttrDict, # current: contains AdapterResponse + agate.Table
agate.Table, # previous: just table
]
result = self.execute_macro(
FRESHNESS_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver
)
result = self.execute_macro(FRESHNESS_MACRO_NAME, kwargs=kwargs)

Check warning on line 1242 in core/dbt/adapters/base/impl.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/base/impl.py#L1242

Added line #L1242 was not covered by tests
if isinstance(result, agate.Table):
warn_or_error(CollectFreshnessReturnSignature())
adapter_response = None
Expand Down Expand Up @@ -1291,15 +1269,12 @@ def calculate_freshness(
def calculate_freshness_from_metadata(
self,
source: BaseRelation,
macro_resolver: Optional[MacroResolverProtocol] = None,
) -> Tuple[Optional[AdapterResponse], FreshnessResponse]:
kwargs: Dict[str, Any] = {
"information_schema": source.information_schema_only(),
"relations": [source],
}
result = self.execute_macro(
GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver
)
result = self.execute_macro(GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs)

Check warning on line 1277 in core/dbt/adapters/base/impl.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/base/impl.py#L1277

Added line #L1277 was not covered by tests
adapter_response, table = result.response, result.table # type: ignore[attr-defined]

try:
Expand Down
9 changes: 7 additions & 2 deletions core/dbt/adapters/contracts/macros.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Optional
from typing import Optional, Dict, Any
from typing_extensions import Protocol

from dbt.common.clients.jinja import MacroProtocol


class MacroResolverProtocol(Protocol):
class MacroClient(Protocol):
def find_macro_by_name(
self, name: str, root_project_name: str, package: Optional[str]
) -> Optional[MacroProtocol]:
raise NotImplementedError("find_macro_by_name not implemented")

def generate_context_for_macro(
self, config, macro: MacroProtocol, package: Optional[str]
) -> Dict[str, Any]:
raise NotImplementedError("generate_context_for_macro not implemented")

Check warning on line 16 in core/dbt/adapters/contracts/macros.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/contracts/macros.py#L16

Added line #L16 was not covered by tests
20 changes: 4 additions & 16 deletions core/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,15 @@
TypeVar,
Tuple,
Any,
Callable,
Dict,
)
from typing_extensions import Protocol

import agate

from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse
from dbt.adapters.contracts.macros import MacroResolverProtocol
from dbt.adapters.contracts.macros import MacroClient
from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig
from dbt.common.contracts.config.base import BaseConfig
from dbt.common.clients.jinja import MacroProtocol
from dbt.contracts.graph.manifest import Manifest


Expand Down Expand Up @@ -80,22 +77,13 @@ class AdapterProtocol( # type: ignore[misc]
def __init__(self, config: AdapterRequiredConfig) -> None:
...

def set_macro_resolver(self, macro_resolver: MacroResolverProtocol) -> None:
def set_macro_client(self, macro_resolver: MacroClient) -> None:
...

def get_macro_resolver(self) -> Optional[MacroResolverProtocol]:
def get_macro_client(self) -> Optional[MacroClient]:
...

def clear_macro_resolver(self) -> None:
...

def set_macro_context_generator(
self,
macro_context_generator: Callable[
[MacroProtocol, AdapterRequiredConfig, MacroResolverProtocol, Optional[str]],
Dict[str, Any],
],
) -> None:
def clear_macro_client(self) -> None:
...

@classmethod
Expand Down
5 changes: 1 addition & 4 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dbt.tracking
from dbt.common.invocation import reset_invocation_id
from dbt.version import installed as installed_version
from dbt.adapters.factory import adapter_management, register_adapter, get_adapter
from dbt.adapters.factory import adapter_management, register_adapter
from dbt.flags import set_flags, get_flag_dict
from dbt.cli.exceptions import (
ExceptionExit,
Expand All @@ -10,7 +10,6 @@
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile, UnsetProfile
from dbt.context.providers import generate_runtime_macro_context

from dbt.common.events.base_types import EventLevel
from dbt.common.events.functions import (
Expand Down Expand Up @@ -275,8 +274,6 @@ def wrapper(*args, **kwargs):

runtime_config = ctx.obj["runtime_config"]
register_adapter(runtime_config)
adapter = get_adapter(runtime_config)
adapter.set_macro_context_generator(generate_runtime_macro_context)

# a manifest has already been set on the context, so don't overwrite it
if ctx.obj.get("manifest") is None:
Expand Down
16 changes: 16 additions & 0 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from dbt.context.macros import MacroNamespaceBuilder, MacroNamespace
from dbt.context.manifest import ManifestContext
from dbt.adapters.contracts.connection import AdapterResponse
from dbt.adapters.contracts.macros import MacroClient
from dbt.contracts.graph.manifest import Manifest, Disabled
from dbt.contracts.graph.nodes import (
Macro,
Expand Down Expand Up @@ -1530,6 +1531,21 @@ def generate_runtime_macro_context(
return ctx.to_dict()


class ManifestMacroClient(MacroClient):
def __init__(self, manifest: Manifest) -> None:
self._manifest = manifest

def find_macro_by_name(
self, name: str, root_project_name: str, package: Optional[str]
) -> Optional[MacroProtocol]:
return self._manifest.find_macro_by_name(name, root_project_name, package)

def generate_context_for_macro(
self, config: RuntimeConfig, macro: MacroProtocol, package: Optional[str]
) -> Dict[str, Any]:
return generate_runtime_macro_context(macro, config, self._manifest, package)


class ExposureRefResolver(BaseResolver):
def __call__(self, *args, **kwargs) -> str:
package = None
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
from dbt.context.docs import generate_runtime_docs_context
from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
from dbt.context.configured import generate_macro_context
from dbt.context.providers import ParseProvider
from dbt.context.providers import ParseProvider, ManifestMacroClient
from dbt.contracts.files import FileHash, ParseFileType, SchemaSourceFile
from dbt.parser.read_files import (
ReadFilesFromFileSystem,
Expand Down Expand Up @@ -286,7 +286,7 @@ def get_full_manifest(
# the config and adapter may be persistent.
if reset:
config.clear_dependencies()
adapter.clear_macro_resolver()
adapter.clear_macro_client()

Check warning on line 289 in core/dbt/parser/manifest.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/manifest.py#L289

Added line #L289 was not covered by tests
macro_hook = adapter.connections.set_query_header

flags = get_flags()
Expand Down Expand Up @@ -1000,7 +1000,7 @@ def build_manifest_state_check(self):

def save_macros_to_adapter(self, adapter):
macro_manifest = MacroManifest(self.manifest.macros)
adapter.set_macro_resolver(macro_manifest)
adapter.set_macro_client(ManifestMacroClient(macro_manifest))
# This executes the callable macro_hook and sets the
# query headers
self.macro_hook(macro_manifest)
Expand Down
4 changes: 1 addition & 3 deletions core/dbt/task/run_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ def _run_unsafe(self, package_name, macro_name) -> agate.Table:

with adapter.connection_named("macro_{}".format(macro_name)):
adapter.clear_transaction()
res = adapter.execute_macro(
macro_name, project=package_name, kwargs=macro_kwargs, macro_resolver=self.manifest
)
res = adapter.execute_macro(macro_name, project=package_name, kwargs=macro_kwargs)

return res

Expand Down
1 change: 0 additions & 1 deletion core/dbt/task/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def execute(self, compiled_node, manifest):
model_context = generate_runtime_model_context(compiled_node, self.config, manifest)
compiled_node.compiled_code = self.adapter.execute_macro(
macro_name="get_show_sql",
macro_resolver=manifest,
context_override=model_context,
kwargs={
"compiled_code": model_context["compiled_code"],
Expand Down
9 changes: 4 additions & 5 deletions core/dbt/tests/fixtures/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from dbt.parser.manifest import ManifestLoader
from dbt.common.exceptions import CompilationError, DbtDatabaseError
from dbt.context.providers import generate_runtime_macro_context
from dbt.context.providers import ManifestMacroClient
import dbt.flags as flags
from dbt.config.runtime import RuntimeConfig
from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters, get_adapter_by_type
Expand Down Expand Up @@ -297,8 +297,7 @@ def adapter(
base_macros_only=True,
)

adapter.set_macro_resolver(manifest)
adapter.set_macro_context_generator(generate_runtime_macro_context)
adapter.set_macro_client(ManifestMacroClient(manifest))
yield adapter
adapter.cleanup_connections()
reset_adapters()
Expand Down Expand Up @@ -459,13 +458,13 @@ def create_test_schema(self, schema_name=None):

# Drop the unique test schema, usually called in test cleanup
def drop_test_schema(self):
if self.adapter.get_macro_resolver() is None:
if self.adapter.get_macro_client() is None:
manifest = ManifestLoader.load_macros(
self.adapter.config,
self.adapter.connections.set_query_header,
base_macros_only=True,
)
self.adapter.set_macro_resolver(manifest)
self.adapter.set_macro_client(ManifestMacroClient(manifest))

with get_connection(self.adapter):
for schema_name in self.created_schemas:
Expand Down
11 changes: 5 additions & 6 deletions tests/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import ManifestStateCheck
from dbt.common.clients import agate_helper
from dbt.context.providers import ManifestMacroClient
from dbt.exceptions import DbtConfigError
from dbt.common.exceptions import DbtValidationError
from dbt.context.providers import generate_runtime_macro_context
from psycopg2 import extensions as psycopg2_extensions
from psycopg2 import DatabaseError

Expand Down Expand Up @@ -429,11 +429,10 @@ def _mock_state_check(self):

self.psycopg2.connect.return_value = self.handle
self.adapter = PostgresAdapter(self.config, self.mp_context)
self.adapter.set_macro_resolver(load_internal_manifest_macros(self.config))
self.adapter.set_macro_context_generator(generate_runtime_macro_context)
self.adapter.connections.query_header = MacroQueryStringSetter(
self.config, self.adapter.get_macro_resolver()
)

macro_manifest = load_internal_manifest_macros(self.config)
self.adapter.set_macro_client(ManifestMacroClient(macro_manifest))
self.adapter.connections.query_header = MacroQueryStringSetter(self.config, macro_manifest)

self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add")
self.mock_query_header_add = self.qh_patch.start()
Expand Down

0 comments on commit 4f498f9

Please sign in to comment.