diff --git a/.changes/unreleased/Features-20240722-133729.yaml b/.changes/unreleased/Features-20240722-133729.yaml new file mode 100644 index 00000000000..ed99bcb6c1f --- /dev/null +++ b/.changes/unreleased/Features-20240722-133729.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support passing `args` as keyword argument for `run-operation` in programmatic invocations +time: 2024-07-22T13:37:29.285621-06:00 +custom: + Author: dbeatty10 + Issue: "10473" diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index ce0a42aae2e..edb1aa34ac0 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -47,9 +47,9 @@ def __init__( callbacks = [] self.callbacks = callbacks - def invoke(self, args: List[str], **kwargs) -> dbtRunnerResult: + def invoke(self, invocation_args: List[str], /, **kwargs) -> dbtRunnerResult: try: - dbt_ctx = cli.make_context(cli.name, args.copy()) + dbt_ctx = cli.make_context(cli.name, invocation_args.copy()) dbt_ctx.obj = { "manifest": self.manifest, "callbacks": self.callbacks, diff --git a/core/dbt/tests/util.py b/core/dbt/tests/util.py index a01ee9b67e2..0cbcab5d1e8 100644 --- a/core/dbt/tests/util.py +++ b/core/dbt/tests/util.py @@ -70,28 +70,25 @@ # run_dbt(["run", "--vars", "seed_name: base"]) # If the command is expected to fail, pass in "expect_pass=False"): # run_dbt(["test"], expect_pass=False) -def run_dbt( - args: Optional[List[str]] = None, - expect_pass: bool = True, -): +def run_dbt(invocation_args: Optional[List[str]] = None, /, expect_pass: bool = True, **kwargs): # reset global vars reset_metadata_vars() - if args is None: - args = ["run"] + if invocation_args is None: + invocation_args = ["run"] - print("\n\nInvoking dbt with {}".format(args)) + print("\n\nInvoking dbt with {}".format(invocation_args)) from dbt.flags import get_flags flags = get_flags() project_dir = getattr(flags, "PROJECT_DIR", None) profiles_dir = getattr(flags, "PROFILES_DIR", None) - if project_dir and "--project-dir" not in args: - args.extend(["--project-dir", project_dir]) - if profiles_dir and "--profiles-dir" not in args: - args.extend(["--profiles-dir", profiles_dir]) + if project_dir and "--project-dir" not in invocation_args: + invocation_args.extend(["--project-dir", project_dir]) + if profiles_dir and "--profiles-dir" not in invocation_args: + invocation_args.extend(["--profiles-dir", profiles_dir]) dbt = dbtRunner() - res = dbt.invoke(args) + res = dbt.invoke(invocation_args, **kwargs) # the exception is immediately raised to be caught in tests # using a pattern like `with pytest.raises(SomeException):` @@ -109,13 +106,12 @@ def run_dbt( # start with the "--debug" flag. The structured schema log CI test # will turn the logs into json, so you have to be prepared for that. def run_dbt_and_capture( - args: Optional[List[str]] = None, - expect_pass: bool = True, + invocation_args: Optional[List[str]] = None, /, expect_pass: bool = True, **kwargs ): try: stringbuf = StringIO() capture_stdout_logs(stringbuf) - res = run_dbt(args, expect_pass=expect_pass) + res = run_dbt(invocation_args, expect_pass=expect_pass, **kwargs) stdout = stringbuf.getvalue() finally: diff --git a/tests/functional/colors/test_colors.py b/tests/functional/colors/test_colors.py index 3f731108d18..169c211a024 100644 --- a/tests/functional/colors/test_colors.py +++ b/tests/functional/colors/test_colors.py @@ -34,7 +34,7 @@ def test_no_use_colors(self, project): ) def assert_colors_used(self, flag, expect_colors): - _, stdout = run_dbt_and_capture(args=[flag, "run"], expect_pass=False) + _, stdout = run_dbt_and_capture([flag, "run"], expect_pass=False) # pattern to match formatted log output pattern = re.compile(r"\[31m.*|\[33m.*") stdout_contains_formatting_characters = bool(pattern.search(stdout)) diff --git a/tests/functional/list/test_list.py b/tests/functional/list/test_list.py index 653021c608b..cf96f31321d 100644 --- a/tests/functional/list/test_list.py +++ b/tests/functional/list/test_list.py @@ -23,7 +23,7 @@ def run_dbt_ls(self, args=None, expect_pass=True): full_args = ["ls"] if args is not None: full_args += args - result = run_dbt(args=full_args, expect_pass=expect_pass) + result = run_dbt(full_args, expect_pass=expect_pass) return result diff --git a/tests/functional/run_operations/fixtures.py b/tests/functional/run_operations/fixtures.py index f6ed82e20ec..7241d5fe2c1 100644 --- a/tests/functional/run_operations/fixtures.py +++ b/tests/functional/run_operations/fixtures.py @@ -52,8 +52,8 @@ {% endmacro %} -{% macro print_something() %} - {{ print("You're doing awesome!") }} +{% macro print_something(message="You're doing awesome!") %} + {{ print(message) }} {% endmacro %} """ diff --git a/tests/functional/run_operations/test_run_operations.py b/tests/functional/run_operations/test_run_operations.py index 064c98b3a51..b265bad0dcc 100644 --- a/tests/functional/run_operations/test_run_operations.py +++ b/tests/functional/run_operations/test_run_operations.py @@ -75,6 +75,12 @@ def test_macro_args(self, project): self.run_operation("table_name_args", table_name="my_fancy_table") check_table_does_exist(project.adapter, "my_fancy_table") + def test_args_as_keyword(self, project): + results, log_output = run_dbt_and_capture( + ["run-operation", "print_something"], args={"message": "Morning coffee"} + ) + assert "Morning coffee" in log_output + def test_macro_exception(self, project): self.run_operation("syntax_error", False) diff --git a/tests/functional/threading/test_thread_count.py b/tests/functional/threading/test_thread_count.py index 9c94356e630..82cbabff9d5 100644 --- a/tests/functional/threading/test_thread_count.py +++ b/tests/functional/threading/test_thread_count.py @@ -42,5 +42,5 @@ def profiles_config_update(self): return {"threads": 2} def test_threading_8x(self, project): - results = run_dbt(args=["run", "--threads", "16"]) + results = run_dbt(["run", "--threads", "16"]) assert len(results), 20