Skip to content

Commit

Permalink
fixup! refactor(utils/decorators): rewrite remove task decorator to u…
Browse files Browse the repository at this point in the history
…se ast
  • Loading branch information
josix authored and potiuk committed Feb 6, 2025
1 parent d451b16 commit 4c515fa
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 121 deletions.
16 changes: 8 additions & 8 deletions airflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ def __init__(self, task_decorator_name: str) -> None:
task_decorator_name.strip("@"),
}

def _is_task_decorator(self, decorator: cst.Decorator) -> bool:
if isinstance(decorator.decorator, cst.Name):
return decorator.decorator.value in self.decorators_to_remove
elif isinstance(decorator.decorator, cst.Attribute) and isinstance(
decorator.decorator.value, cst.Name
def _is_task_decorator(self, decorator_node: cst.Decorator) -> bool:
if isinstance(decorator_node.decorator, cst.Name):
return decorator_node.decorator.value in self.decorators_to_remove
elif isinstance(decorator_node.decorator, cst.Attribute) and isinstance(
decorator_node.decorator.value, cst.Name
):
return (
f"{decorator.decorator.value.value}.{decorator.decorator.attr.value}"
f"{decorator_node.decorator.value.value}.{decorator_node.decorator.attr.value}"
in self.decorators_to_remove
)
elif isinstance(decorator.decorator, cst.Call):
return self._is_task_decorator(cst.Decorator(decorator=decorator.decorator.func))
elif isinstance(decorator_node.decorator, cst.Call):
return self._is_task_decorator(cst.Decorator(decorator=decorator_node.decorator.func))
return False

def leave_FunctionDef(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,78 +192,29 @@ def test_should_create_virtualenv_with_extra_packages_uv(self, mock_execute_in_s
["uv", "pip", "install", "--python", "/VENV/bin/python", "apache-beam[gcp]"]
)

def test_remove_task_decorator(self):
py_source = dedent(
"""
@task.virtualenv(serializer="dill")
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == expected_source

def test_remove_decorator_no_parens(self):
py_source = dedent(
"""
@task.virtualenv
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == expected_source

def test_remove_decorator_including_comment(self):
py_source = dedent(
"""
@task.virtualenv
def f():
# @task.virtualenv
import funcsigs
"""
)
expected_source = dedent(
@pytest.mark.parametrize(
"decorators, expected_decorators",
[
(["@task.virtualenv"], []),
(["@task.virtualenv()"], []),
(['@task.virtualenv(serializer="dill")'], []),
(["@foo", "@task.virtualenv", "@bar"], ["@foo", "@bar"]),
(["@foo", "@task.virtualenv()", "@bar"], ["@foo", "@bar"]),
],
ids=["without_parens", "parens", "with_args", "nested_without_parens", "nested_with_parens"],
)
def test_remove_task_decorator(self, decorators: list[str], expected_decorators: list[str]):
decorator = "\n".join(decorators)
expected_decorator = "\n".join(expected_decorators)
SCRIPT = dedent(
"""
def f():
# @task.virtualenv
import funcsigs
"""
)
py_source = decorator + SCRIPT
expected_source = expected_decorator + SCRIPT if expected_decorator else SCRIPT.lstrip()

res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == expected_source

@pytest.mark.parametrize("decorator", ["@task.virtualenv", "@task.virtualenv()"])
def test_remove_decorator_nested(self, decorator):
py_source = dedent(
f"""
@foo
{decorator}
@bar
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
@foo
@bar
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == expected_source
64 changes: 17 additions & 47 deletions tests/utils/test_preexisting_python_virtualenv_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,58 +25,28 @@


class TestExternalPythonDecorator:
def test_remove_task_decorator(self):
py_source = dedent(
@pytest.mark.parametrize(
"decorators, expected_decorators",
[
(["@task.external_python"], []),
(["@task.external_python()"], []),
(['@task.external_python(serializer="dill")'], []),
(["@foo", "@task.external_python", "@bar"], ["@foo", "@bar"]),
(["@foo", "@task.external_python()", "@bar"], ["@foo", "@bar"]),
],
ids=["without_parens", "parens", "with_args", "nested_without_parens", "nested_with_parens"],
)
def test_remove_task_decorator(self, decorators: list[str], expected_decorators: list[str]):
decorator = "\n".join(decorators)
expected_decorator = "\n".join(expected_decorators)
SCRIPT = dedent(
"""
@task.external_python(serializer="dill")
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == expected_source
py_source = decorator + SCRIPT
expected_source = expected_decorator + SCRIPT if expected_decorator else SCRIPT.lstrip()

def test_remove_decorator_no_parens(self):
py_source = dedent(
"""
@task.external_python
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == expected_source

@pytest.mark.parametrize("decorator", ["@task.external_python", "@task.external_python()"])
def test_remove_decorator_nested(self, decorator):
py_source = dedent(
f"""
@foo
{decorator}
@bar
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
@foo
@bar
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == expected_source

0 comments on commit 4c515fa

Please sign in to comment.