diff --git a/.changes/unreleased/Features-20241031-160641.yaml b/.changes/unreleased/Features-20241031-160641.yaml new file mode 100644 index 000000000..758a17ca9 --- /dev/null +++ b/.changes/unreleased/Features-20241031-160641.yaml @@ -0,0 +1,6 @@ +kind: Features +body: allows to use prj. vars in python models +time: 2024-10-31T16:06:41.214334214-03:00 +custom: + Author: devmessias + Issue: 5617 10914 diff --git a/dbt/include/global_project/macros/python_model/python.sql b/dbt/include/global_project/macros/python_model/python.sql index d658ff185..1e8c2dc22 100644 --- a/dbt/include/global_project/macros/python_model/python.sql +++ b/dbt/include/global_project/macros/python_model/python.sql @@ -58,6 +58,21 @@ def source(*args, dbt_load_df_function): config_dict = {{ config_dict }} {% endmacro %} + +{% macro build_var_dict(model, context) %} + {%- set var_dict = {} -%} + {% set var_dbt_used = zip(model.config.var_keys_used, model.config.var_keys_defaults) | list %} + {%- for key, default in var_dbt_used -%} + {# weird type testing with enum, would be much easier to write this logic in Python! #} + {%- if key == "language" -%} + {%- set value = "python" -%} + {%- endif -%} + {%- set value = context.var(key, default) -%} + {%- do var_dict.update({key: value}) -%} + {%- endfor -%} +var_dict = {{ var_dict }} +{% endmacro %} + {% macro py_script_postfix(model) %} # This part is user provided model code # you will need to copy the next section to run the code @@ -67,6 +82,7 @@ config_dict = {{ config_dict }} {{ build_ref_function(model ) }} {{ build_source_function(model ) }} {{ build_config_dict(model) }} +{{ build_var_dict(model, context) }} class config: def __init__(self, *args, **kwargs): @@ -76,6 +92,15 @@ class config: def get(key, default=None): return config_dict.get(key, default) +class DbtVar: + def __init__(self, *args, **kwargs): + pass + + @staticmethod + def get(key, default=None): + return var_dict.get(key, default) + + class this: """dbt.this() or dbt.this.identifier""" database = "{{ this.database }}" @@ -91,6 +116,7 @@ class dbtObj: self.source = lambda *args: source(*args, dbt_load_df_function=load_df_function) self.ref = lambda *args, **kwargs: ref(*args, **kwargs, dbt_load_df_function=load_df_function) self.config = config + self.var = DbtVar self.this = this() self.is_incremental = {{ is_incremental() }}