Skip to content

Commit

Permalink
Merge branch 'main' into make-logical-date-as-required-field-to-trigg…
Browse files Browse the repository at this point in the history
…er-dag-api
  • Loading branch information
vatsrahul1001 authored Feb 7, 2025
2 parents 5467412 + 9689cf5 commit 3da8eae
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 73 deletions.
70 changes: 34 additions & 36 deletions airflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,54 +18,52 @@
from __future__ import annotations

import sys
from collections import deque
from typing import Callable, TypeVar

import libcst as cst

T = TypeVar("T", bound=Callable)


class _TaskDecoratorRemover(cst.CSTTransformer):
def __init__(self, task_decorator_name: str) -> None:
self.decorators_to_remove: set[str] = {
"setup",
"teardown",
"task.skip_if",
"task.run_if",
task_decorator_name.strip("@"),
}

def _is_task_decorator(self, decorator_node: cst.Decorator) -> bool:
decorator_expr = decorator_node.decorator
if isinstance(decorator_expr, cst.Name):
return decorator_expr.value in self.decorators_to_remove
elif isinstance(decorator_expr, cst.Attribute) and isinstance(decorator_expr.value, cst.Name):
return f"{decorator_expr.value.value}.{decorator_expr.attr.value}" in self.decorators_to_remove
elif isinstance(decorator_expr, cst.Call):
return self._is_task_decorator(cst.Decorator(decorator=decorator_expr.func))
return False

def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
new_decorators = [dec for dec in updated_node.decorators if not self._is_task_decorator(dec)]
if len(new_decorators) == len(updated_node.decorators):
return updated_node
return updated_node.with_changes(decorators=new_decorators)


def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
"""
Remove @task or similar decorators as well as @setup and @teardown.
:param python_source: The python source code
:param task_decorator_name: the decorator name
TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse
"""

def _remove_task_decorator(py_source, decorator_name):
# if no line starts with @decorator_name, we can early exit
for line in py_source.split("\n"):
if line.startswith(decorator_name):
break
else:
return python_source
split = python_source.split(decorator_name, 1)
before_decorator, after_decorator = split[0], split[1]
if after_decorator[0] == "(":
after_decorator = _balance_parens(after_decorator)
if after_decorator[0] == "\n":
after_decorator = after_decorator[1:]
return before_decorator + after_decorator

decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name]
for decorator in decorators:
python_source = _remove_task_decorator(python_source, decorator)
return python_source


def _balance_parens(after_decorator):
num_paren = 1
after_decorator = deque(after_decorator)
after_decorator.popleft()
while num_paren:
current = after_decorator.popleft()
if current == "(":
num_paren = num_paren + 1
elif current == ")":
num_paren = num_paren - 1
return "".join(after_decorator)
source_tree = cst.parse_module(python_source)
modified_tree = source_tree.visit(_TaskDecoratorRemover(task_decorator_name))
return modified_tree.code


class _autostacklevel_warn:
Expand Down
1 change: 1 addition & 0 deletions hatch_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@
"jinja2>=3.0.0",
"jsonschema>=4.18.0",
"lazy-object-proxy>=1.2.0",
"libcst >=1.1.0",
"linkify-it-py>=2.0.0",
"lockfile>=0.12.2",
"markdown-it-py>=2.1.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from pathlib import Path
from textwrap import dedent
from unittest import mock

import pytest
Expand Down Expand Up @@ -191,26 +192,29 @@ def test_should_create_virtualenv_with_extra_packages_uv(self, mock_execute_in_s
["uv", "pip", "install", "--python", "/VENV/bin/python", "apache-beam[gcp]"]
)

def test_remove_task_decorator(self):
py_source = '@task.virtualenv(serializer="dill")\ndef f():\nimport funcsigs'
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "def f():\nimport funcsigs"

def test_remove_decorator_no_parens(self):
py_source = "@task.virtualenv\ndef f():\nimport funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "def f():\nimport funcsigs"

def test_remove_decorator_including_comment(self):
py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\nimport funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "def f():\n# @task.virtualenv\nimport funcsigs"

def test_remove_decorator_nested(self):
py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\nimport funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
@pytest.mark.parametrize(
"decorators, expected_decorators",
[
(["@task.virtualenv"], []),
(["@task.virtualenv()"], []),
(['@task.virtualenv(serializer="dill")'], []),
(["@foo", "@task.virtualenv", "@bar"], ["@foo", "@bar"]),
(["@foo", "@task.virtualenv()", "@bar"], ["@foo", "@bar"]),
],
ids=["without_parens", "parens", "with_args", "nested_without_parens", "nested_with_parens"],
)
def test_remove_task_decorator(self, decorators: list[str], expected_decorators: list[str]):
concated_decorators = "\n".join(decorators)
expected_decorator = "\n".join(expected_decorators)
SCRIPT = dedent(
"""
def f():
# @task.virtualenv
import funcsigs
"""
)
py_source = concated_decorators + SCRIPT
expected_source = expected_decorator + SCRIPT if expected_decorator else SCRIPT.lstrip()

py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\nimport funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
assert res == expected_source
43 changes: 27 additions & 16 deletions tests/utils/test_preexisting_python_virtualenv_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,36 @@
# under the License.
from __future__ import annotations

from airflow.utils.decorators import remove_task_decorator
from textwrap import dedent

import pytest

class TestExternalPythonDecorator:
def test_remove_task_decorator(self):
py_source = '@task.external_python(serializer="dill")\ndef f():\nimport funcsigs'
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "def f():\nimport funcsigs"
from airflow.utils.decorators import remove_task_decorator

def test_remove_decorator_no_parens(self):
py_source = "@task.external_python\ndef f():\nimport funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "def f():\nimport funcsigs"

def test_remove_decorator_nested(self):
py_source = "@foo\n@task.external_python\n@bar\ndef f():\nimport funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
class TestExternalPythonDecorator:
@pytest.mark.parametrize(
"decorators, expected_decorators",
[
(["@task.external_python"], []),
(["@task.external_python()"], []),
(['@task.external_python(serializer="dill")'], []),
(["@foo", "@task.external_python", "@bar"], ["@foo", "@bar"]),
(["@foo", "@task.external_python()", "@bar"], ["@foo", "@bar"]),
],
ids=["without_parens", "parens", "with_args", "nested_without_parens", "nested_with_parens"],
)
def test_remove_task_decorator(self, decorators: list[str], expected_decorators: list[str]):
concated_decorators = "\n".join(decorators)
expected_decorator = "\n".join(expected_decorators)
SCRIPT = dedent(
"""
def f():
import funcsigs
"""
)
py_source = concated_decorators + SCRIPT
expected_source = expected_decorator + SCRIPT if expected_decorator else SCRIPT.lstrip()

py_source = "@foo\n@task.external_python()\n@bar\ndef f():\nimport funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
assert res == expected_source

0 comments on commit 3da8eae

Please sign in to comment.