diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 0af6b4fb255..fb101c003e6 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -28,7 +28,7 @@ ConstraintType, ModelLevelConstraint, ) -from dbt.adapters.contracts.macros import MacroResolverProtocol +from dbt.adapters.contracts.macros import MacroClient import agate import pytz @@ -62,7 +62,7 @@ table_from_rows, Integer, ) -from dbt.common.clients.jinja import CallableMacroGenerator, MacroProtocol +from dbt.common.clients.jinja import CallableMacroGenerator from dbt.common.events.functions import fire_event, warn_or_error from dbt.adapters.events.types import ( CacheMiss, @@ -80,7 +80,6 @@ Connection, AdapterResponse, BaseConnectionManager, - AdapterRequiredConfig, ) from dbt.adapters.base.meta import AdapterMeta, available from dbt.adapters.base.relation import ( @@ -262,35 +261,20 @@ def __init__(self, config, mp_context: SpawnContext) -> None: self.config = config self.cache = RelationsCache(log_cache_events=config.log_cache_events) self.connections = self.ConnectionManager(config, mp_context) - self._macro_resolver: Optional[MacroResolverProtocol] = None - self._macro_context_generator: Optional[ - Callable[ - [MacroProtocol, AdapterRequiredConfig, MacroResolverProtocol, Optional[str]], - Dict[str, Any], - ] - ] = None + self._macro_client: Optional[MacroClient] = None ### - # Methods to set / access a macro resolver + # Methods to set / access a macro client ### - def set_macro_resolver(self, macro_resolver: MacroResolverProtocol) -> None: - self._macro_resolver = macro_resolver + def set_macro_client(self, macro_client: MacroClient) -> None: + self._macro_client = macro_client - def get_macro_resolver(self) -> Optional[MacroResolverProtocol]: - return self._macro_resolver + def get_macro_client(self) -> Optional[MacroClient]: + return self._macro_client - def clear_macro_resolver(self) -> None: - if self._macro_resolver is not None: - self._macro_resolver = None - - def set_macro_context_generator( - self, - macro_context_generator: Callable[ - [MacroProtocol, AdapterRequiredConfig, MacroResolverProtocol, Optional[str]], - Dict[str, Any], - ], - ) -> None: - self._macro_context_generator = macro_context_generator + def clear_macro_client(self) -> None: + if self._macro_client is not None: + self._macro_client = None ### # Methods that pass through to the connection manager @@ -1052,7 +1036,6 @@ def convert_agate_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[ def execute_macro( self, macro_name: str, - macro_resolver: Optional[MacroResolverProtocol] = None, project: Optional[str] = None, context_override: Optional[Dict[str, Any]] = None, kwargs: Optional[Dict[str, Any]] = None, @@ -1075,14 +1058,12 @@ def execute_macro( if context_override is None: context_override = {} - resolver = macro_resolver or self._macro_resolver - if resolver is None: - raise DbtInternalError("Macro resolver was None when calling execute_macro!") - - if self._macro_context_generator is None: - raise DbtInternalError("Macro context generator was None when calling execute_macro!") + if self._macro_client is None: + raise DbtInternalError("Macro client was None when calling execute_macro!") - macro = resolver.find_macro_by_name(macro_name, self.config.project_name, project) + macro = self._macro_client.find_macro_by_name( + macro_name, self.config.project_name, project + ) if macro is None: if project is None: package_name = "any package" @@ -1095,7 +1076,7 @@ def execute_macro( ) ) - macro_context = self._macro_context_generator(macro, self.config, resolver, project) + macro_context = self._macro_client.generate_context_for_macro(self.config, macro, project) macro_context.update(context_override) macro_function = CallableMacroGenerator(macro, macro_context) @@ -1243,7 +1224,6 @@ def calculate_freshness( source: BaseRelation, loaded_at_field: str, filter: Optional[str], - macro_resolver: Optional[MacroResolverProtocol] = None, ) -> Tuple[Optional[AdapterResponse], FreshnessResponse]: """Calculate the freshness of sources in dbt, and return it""" kwargs: Dict[str, Any] = { @@ -1259,9 +1239,7 @@ def calculate_freshness( AttrDict, # current: contains AdapterResponse + agate.Table agate.Table, # previous: just table ] - result = self.execute_macro( - FRESHNESS_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver - ) + result = self.execute_macro(FRESHNESS_MACRO_NAME, kwargs=kwargs) if isinstance(result, agate.Table): warn_or_error(CollectFreshnessReturnSignature()) adapter_response = None @@ -1291,15 +1269,12 @@ def calculate_freshness( def calculate_freshness_from_metadata( self, source: BaseRelation, - macro_resolver: Optional[MacroResolverProtocol] = None, ) -> Tuple[Optional[AdapterResponse], FreshnessResponse]: kwargs: Dict[str, Any] = { "information_schema": source.information_schema_only(), "relations": [source], } - result = self.execute_macro( - GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver - ) + result = self.execute_macro(GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs) adapter_response, table = result.response, result.table # type: ignore[attr-defined] try: diff --git a/core/dbt/adapters/contracts/macros.py b/core/dbt/adapters/contracts/macros.py index 151c9c44dde..7280a138444 100644 --- a/core/dbt/adapters/contracts/macros.py +++ b/core/dbt/adapters/contracts/macros.py @@ -1,11 +1,16 @@ -from typing import Optional +from typing import Optional, Dict, Any from typing_extensions import Protocol from dbt.common.clients.jinja import MacroProtocol -class MacroResolverProtocol(Protocol): +class MacroClient(Protocol): def find_macro_by_name( self, name: str, root_project_name: str, package: Optional[str] ) -> Optional[MacroProtocol]: raise NotImplementedError("find_macro_by_name not implemented") + + def generate_context_for_macro( + self, config, macro: MacroProtocol, package: Optional[str] + ) -> Dict[str, Any]: + raise NotImplementedError("generate_context_for_macro not implemented") diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py index 9d5e5b5ad8e..64a6d35e973 100644 --- a/core/dbt/adapters/protocol.py +++ b/core/dbt/adapters/protocol.py @@ -9,18 +9,15 @@ TypeVar, Tuple, Any, - Callable, - Dict, ) from typing_extensions import Protocol import agate from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse -from dbt.adapters.contracts.macros import MacroResolverProtocol +from dbt.adapters.contracts.macros import MacroClient from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig from dbt.common.contracts.config.base import BaseConfig -from dbt.common.clients.jinja import MacroProtocol from dbt.contracts.graph.manifest import Manifest @@ -80,22 +77,13 @@ class AdapterProtocol( # type: ignore[misc] def __init__(self, config: AdapterRequiredConfig) -> None: ... - def set_macro_resolver(self, macro_resolver: MacroResolverProtocol) -> None: + def set_macro_client(self, macro_resolver: MacroClient) -> None: ... - def get_macro_resolver(self) -> Optional[MacroResolverProtocol]: + def get_macro_client(self) -> Optional[MacroClient]: ... - def clear_macro_resolver(self) -> None: - ... - - def set_macro_context_generator( - self, - macro_context_generator: Callable[ - [MacroProtocol, AdapterRequiredConfig, MacroResolverProtocol, Optional[str]], - Dict[str, Any], - ], - ) -> None: + def clear_macro_client(self) -> None: ... @classmethod diff --git a/core/dbt/cli/requires.py b/core/dbt/cli/requires.py index 1aff882b569..377ebcd066a 100644 --- a/core/dbt/cli/requires.py +++ b/core/dbt/cli/requires.py @@ -1,7 +1,7 @@ import dbt.tracking from dbt.common.invocation import reset_invocation_id from dbt.version import installed as installed_version -from dbt.adapters.factory import adapter_management, register_adapter, get_adapter +from dbt.adapters.factory import adapter_management, register_adapter from dbt.flags import set_flags, get_flag_dict from dbt.cli.exceptions import ( ExceptionExit, @@ -10,7 +10,6 @@ from dbt.cli.flags import Flags from dbt.config import RuntimeConfig from dbt.config.runtime import load_project, load_profile, UnsetProfile -from dbt.context.providers import generate_runtime_macro_context from dbt.common.events.base_types import EventLevel from dbt.common.events.functions import ( @@ -275,8 +274,6 @@ def wrapper(*args, **kwargs): runtime_config = ctx.obj["runtime_config"] register_adapter(runtime_config) - adapter = get_adapter(runtime_config) - adapter.set_macro_context_generator(generate_runtime_macro_context) # a manifest has already been set on the context, so don't overwrite it if ctx.obj.get("manifest") is None: diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 1f99665af7e..b38ff5b6bd2 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -29,6 +29,7 @@ from dbt.context.macros import MacroNamespaceBuilder, MacroNamespace from dbt.context.manifest import ManifestContext from dbt.adapters.contracts.connection import AdapterResponse +from dbt.adapters.contracts.macros import MacroClient from dbt.contracts.graph.manifest import Manifest, Disabled from dbt.contracts.graph.nodes import ( Macro, @@ -1530,6 +1531,21 @@ def generate_runtime_macro_context( return ctx.to_dict() +class ManifestMacroClient(MacroClient): + def __init__(self, manifest: Manifest) -> None: + self._manifest = manifest + + def find_macro_by_name( + self, name: str, root_project_name: str, package: Optional[str] + ) -> Optional[MacroProtocol]: + return self._manifest.find_macro_by_name(name, root_project_name, package) + + def generate_context_for_macro( + self, config: RuntimeConfig, macro: MacroProtocol, package: Optional[str] + ) -> Dict[str, Any]: + return generate_runtime_macro_context(macro, config, self._manifest, package) + + class ExposureRefResolver(BaseResolver): def __call__(self, *args, **kwargs) -> str: package = None diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index c952db063f4..1ae05b559b0 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -74,7 +74,7 @@ from dbt.context.docs import generate_runtime_docs_context from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace from dbt.context.configured import generate_macro_context -from dbt.context.providers import ParseProvider +from dbt.context.providers import ParseProvider, ManifestMacroClient from dbt.contracts.files import FileHash, ParseFileType, SchemaSourceFile from dbt.parser.read_files import ( ReadFilesFromFileSystem, @@ -286,7 +286,7 @@ def get_full_manifest( # the config and adapter may be persistent. if reset: config.clear_dependencies() - adapter.clear_macro_resolver() + adapter.clear_macro_client() macro_hook = adapter.connections.set_query_header flags = get_flags() @@ -1000,7 +1000,7 @@ def build_manifest_state_check(self): def save_macros_to_adapter(self, adapter): macro_manifest = MacroManifest(self.manifest.macros) - adapter.set_macro_resolver(macro_manifest) + adapter.set_macro_client(ManifestMacroClient(macro_manifest)) # This executes the callable macro_hook and sets the # query headers self.macro_hook(macro_manifest) diff --git a/core/dbt/task/run_operation.py b/core/dbt/task/run_operation.py index 379d5ec6ab8..1e79d791f0b 100644 --- a/core/dbt/task/run_operation.py +++ b/core/dbt/task/run_operation.py @@ -40,9 +40,7 @@ def _run_unsafe(self, package_name, macro_name) -> agate.Table: with adapter.connection_named("macro_{}".format(macro_name)): adapter.clear_transaction() - res = adapter.execute_macro( - macro_name, project=package_name, kwargs=macro_kwargs, macro_resolver=self.manifest - ) + res = adapter.execute_macro(macro_name, project=package_name, kwargs=macro_kwargs) return res diff --git a/core/dbt/task/show.py b/core/dbt/task/show.py index 961a36c6127..c6d47226507 100644 --- a/core/dbt/task/show.py +++ b/core/dbt/task/show.py @@ -27,7 +27,6 @@ def execute(self, compiled_node, manifest): model_context = generate_runtime_model_context(compiled_node, self.config, manifest) compiled_node.compiled_code = self.adapter.execute_macro( macro_name="get_show_sql", - macro_resolver=manifest, context_override=model_context, kwargs={ "compiled_code": model_context["compiled_code"], diff --git a/core/dbt/tests/fixtures/project.py b/core/dbt/tests/fixtures/project.py index fedc3962d46..518ccb62170 100644 --- a/core/dbt/tests/fixtures/project.py +++ b/core/dbt/tests/fixtures/project.py @@ -9,7 +9,7 @@ from dbt.parser.manifest import ManifestLoader from dbt.common.exceptions import CompilationError, DbtDatabaseError -from dbt.context.providers import generate_runtime_macro_context +from dbt.context.providers import ManifestMacroClient import dbt.flags as flags from dbt.config.runtime import RuntimeConfig from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters, get_adapter_by_type @@ -297,8 +297,7 @@ def adapter( base_macros_only=True, ) - adapter.set_macro_resolver(manifest) - adapter.set_macro_context_generator(generate_runtime_macro_context) + adapter.set_macro_client(ManifestMacroClient(manifest)) yield adapter adapter.cleanup_connections() reset_adapters() @@ -459,13 +458,13 @@ def create_test_schema(self, schema_name=None): # Drop the unique test schema, usually called in test cleanup def drop_test_schema(self): - if self.adapter.get_macro_resolver() is None: + if self.adapter.get_macro_client() is None: manifest = ManifestLoader.load_macros( self.adapter.config, self.adapter.connections.set_query_header, base_macros_only=True, ) - self.adapter.set_macro_resolver(manifest) + self.adapter.set_macro_client(ManifestMacroClient(manifest)) with get_connection(self.adapter): for schema_name in self.created_schemas: diff --git a/tests/unit/test_postgres_adapter.py b/tests/unit/test_postgres_adapter.py index c6221f288be..7d8798bf415 100644 --- a/tests/unit/test_postgres_adapter.py +++ b/tests/unit/test_postgres_adapter.py @@ -16,9 +16,9 @@ from dbt.contracts.files import FileHash from dbt.contracts.graph.manifest import ManifestStateCheck from dbt.common.clients import agate_helper +from dbt.context.providers import ManifestMacroClient from dbt.exceptions import DbtConfigError from dbt.common.exceptions import DbtValidationError -from dbt.context.providers import generate_runtime_macro_context from psycopg2 import extensions as psycopg2_extensions from psycopg2 import DatabaseError @@ -429,11 +429,10 @@ def _mock_state_check(self): self.psycopg2.connect.return_value = self.handle self.adapter = PostgresAdapter(self.config, self.mp_context) - self.adapter.set_macro_resolver(load_internal_manifest_macros(self.config)) - self.adapter.set_macro_context_generator(generate_runtime_macro_context) - self.adapter.connections.query_header = MacroQueryStringSetter( - self.config, self.adapter.get_macro_resolver() - ) + + macro_manifest = load_internal_manifest_macros(self.config) + self.adapter.set_macro_client(ManifestMacroClient(macro_manifest)) + self.adapter.connections.query_header = MacroQueryStringSetter(self.config, macro_manifest) self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add") self.mock_query_header_add = self.qh_patch.start()