diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index 78044e4e35761..9308f0eccb69a 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -18,54 +18,55 @@ from __future__ import annotations import sys -from collections import deque from typing import Callable, TypeVar +import libcst as cst + T = TypeVar("T", bound=Callable) +class _TaskDecoratorRemover(cst.CSTTransformer): + def __init__(self, task_decorator_name): + self.decorators_to_remove = { + "setup", + "teardown", + "task.skip_if", + "task.run_if", + task_decorator_name.strip("@"), + } + + def _is_task_decorator(self, decorator: cst.Decorator) -> bool: + if isinstance(decorator.decorator, cst.Name): + return decorator.decorator.value in self.decorators_to_remove + elif isinstance(decorator.decorator, cst.Attribute): + if isinstance(decorator.decorator.value, cst.Name): + return ( + f"{decorator.decorator.value.value}.{decorator.decorator.attr.value}" + in self.decorators_to_remove + ) + elif isinstance(decorator.decorator, cst.Call): + return self._is_task_decorator(cst.Decorator(decorator=decorator.decorator.func)) + return False + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: + new_decorators = [dec for dec in updated_node.decorators if not self._is_task_decorator(dec)] + if len(new_decorators) == len(updated_node.decorators): + return updated_node + return updated_node.with_changes(decorators=new_decorators) + + def remove_task_decorator(python_source: str, task_decorator_name: str) -> str: """ Remove @task or similar decorators as well as @setup and @teardown. :param python_source: The python source code :param task_decorator_name: the decorator name - - TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse """ - - def _remove_task_decorator(py_source, decorator_name): - # if no line starts with @decorator_name, we can early exit - for line in py_source.split("\n"): - if line.startswith(decorator_name): - break - else: - return python_source - split = python_source.split(decorator_name, 1) - before_decorator, after_decorator = split[0], split[1] - if after_decorator[0] == "(": - after_decorator = _balance_parens(after_decorator) - if after_decorator[0] == "\n": - after_decorator = after_decorator[1:] - return before_decorator + after_decorator - - decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name] - for decorator in decorators: - python_source = _remove_task_decorator(python_source, decorator) - return python_source - - -def _balance_parens(after_decorator): - num_paren = 1 - after_decorator = deque(after_decorator) - after_decorator.popleft() - while num_paren: - current = after_decorator.popleft() - if current == "(": - num_paren = num_paren + 1 - elif current == ")": - num_paren = num_paren - 1 - return "".join(after_decorator) + source_tree = cst.parse_module(python_source) + modified_tree = source_tree.visit(_TaskDecoratorRemover(task_decorator_name)) + return modified_tree.code class _autostacklevel_warn: diff --git a/hatch_build.py b/hatch_build.py index cb0309c942981..ba456d710fad6 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -391,6 +391,7 @@ "jinja2>=3.0.0", "jsonschema>=4.18.0", "lazy-object-proxy>=1.2.0", + "libcst >=1.1.0", "linkify-it-py>=2.0.0", "lockfile>=0.12.2", "markdown-it-py>=2.1.0", diff --git a/providers/tests/standard/utils/test_python_virtualenv.py b/providers/tests/standard/utils/test_python_virtualenv.py index d1a7eef94ec26..23a8d1a3cd0a9 100644 --- a/providers/tests/standard/utils/test_python_virtualenv.py +++ b/providers/tests/standard/utils/test_python_virtualenv.py @@ -192,25 +192,25 @@ def test_should_create_virtualenv_with_extra_packages_uv(self, mock_execute_in_s ) def test_remove_task_decorator(self): - py_source = '@task.virtualenv(serializer="dill")\ndef f():\nimport funcsigs' + py_source = '@task.virtualenv(serializer="dill")\ndef f():\n import funcsigs' res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "def f():\nimport funcsigs" + assert res == "def f():\n import funcsigs" def test_remove_decorator_no_parens(self): - py_source = "@task.virtualenv\ndef f():\nimport funcsigs" + py_source = "@task.virtualenv\ndef f():\n import funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "def f():\nimport funcsigs" + assert res == "def f():\n import funcsigs" def test_remove_decorator_including_comment(self): - py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\nimport funcsigs" + py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\n import funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "def f():\n# @task.virtualenv\nimport funcsigs" + assert res == "def f():\n# @task.virtualenv\n import funcsigs" def test_remove_decorator_nested(self): - py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\nimport funcsigs" + py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\n import funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + assert res == "@foo\n@bar\ndef f():\n import funcsigs" - py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\nimport funcsigs" + py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\n import funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + assert res == "@foo\n@bar\ndef f():\n import funcsigs" diff --git a/tests/utils/test_preexisting_python_virtualenv_decorator.py b/tests/utils/test_preexisting_python_virtualenv_decorator.py index 11d80e348ea81..15677e844d7d4 100644 --- a/tests/utils/test_preexisting_python_virtualenv_decorator.py +++ b/tests/utils/test_preexisting_python_virtualenv_decorator.py @@ -22,20 +22,20 @@ class TestExternalPythonDecorator: def test_remove_task_decorator(self): - py_source = '@task.external_python(serializer="dill")\ndef f():\nimport funcsigs' + py_source = '@task.external_python(serializer="dill")\ndef f():\n import funcsigs' res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "def f():\nimport funcsigs" + assert res == "def f():\n import funcsigs" def test_remove_decorator_no_parens(self): - py_source = "@task.external_python\ndef f():\nimport funcsigs" + py_source = "@task.external_python\ndef f():\n import funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "def f():\nimport funcsigs" + assert res == "def f():\n import funcsigs" def test_remove_decorator_nested(self): - py_source = "@foo\n@task.external_python\n@bar\ndef f():\nimport funcsigs" + py_source = "@foo\n@task.external_python\n@bar\ndef f():\n import funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + assert res == "@foo\n@bar\ndef f():\n import funcsigs" - py_source = "@foo\n@task.external_python()\n@bar\ndef f():\nimport funcsigs" + py_source = "@foo\n@task.external_python()\n@bar\ndef f():\n import funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + assert res == "@foo\n@bar\ndef f():\n import funcsigs"