Skip to content

Commit

Permalink
Speed up dag.clear() when clearing lots of ExternalTaskSensor and E…
Browse files Browse the repository at this point in the history
…xternalTaskMarker (apache#11184)

This is an improvement to the UI response time when clearing dozens of DagRuns of large DAGs (thousands of tasks) containing many ExternalTaskSensor + ExternalTaskMarker pairs. In the current implementation, clearing tasks can get slow especially if the user chooses to clear with Future, Downstream and Recursive all selected.

This PR speeds it up. There are two major improvements:

Updating self._task_group in dag.sub_dag() is improved to not deep copy _task_group because it's a waste of time. Instead, do something like dag.task_dict, set it to None first and then copy explicitly.
Pass the TaskInstance already visited down the recursive calls of dag.clear() as visited_external_tis. This speeds up the example in test_clear_overlapping_external_task_marker by almost five folds.
For real large dags containing 500 tasks set up in a similar manner, the time it takes to clear 30 DagRun is cut from around 100s to less than 10s.
  • Loading branch information
yuqian90 authored Oct 22, 2020
1 parent 727c739 commit 4f2e0cf
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 61 deletions.
145 changes: 84 additions & 61 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,7 @@ def clear(
recursion_depth=0,
max_recursion_depth=None,
dag_bag=None,
visited_external_tis=None,
):
"""
Clears a set of task instances associated with the current dag for
Expand Down Expand Up @@ -1138,6 +1139,9 @@ def clear(
:type max_recursion_depth: int
:param dag_bag: The DagBag used to find the dags
:type dag_bag: airflow.models.dagbag.DagBag
:param visited_external_tis: A set used internally to keep track of the visited TaskInstance when
clearing tasks across multiple DAGs linked by ExternalTaskMarker to avoid redundant work.
:type visited_external_tis: set
"""
TI = TaskInstance
tis = session.query(TI)
Expand Down Expand Up @@ -1172,7 +1176,8 @@ def clear(
session=session,
recursion_depth=recursion_depth,
max_recursion_depth=max_recursion_depth,
dag_bag=dag_bag
dag_bag=dag_bag,
visited_external_tis=visited_external_tis
))

if start_date:
Expand All @@ -1193,51 +1198,60 @@ def clear(
instances = tis.all()
for ti in instances:
if ti.operator == ExternalTaskMarker.__name__:
task: ExternalTaskMarker = cast(ExternalTaskMarker, copy.copy(self.get_task(ti.task_id)))
ti.task = task

if recursion_depth == 0:
# Maximum recursion depth allowed is the recursion_depth of the first
# ExternalTaskMarker in the tasks to be cleared.
max_recursion_depth = task.recursion_depth

if recursion_depth + 1 > max_recursion_depth:
# Prevent cycles or accidents.
raise AirflowException("Maximum recursion depth {} reached for {} {}. "
"Attempted to clear too many tasks "
"or there may be a cyclic dependency."
.format(max_recursion_depth,
ExternalTaskMarker.__name__, ti.task_id))
ti.render_templates()
external_tis = session.query(TI).filter(TI.dag_id == task.external_dag_id,
TI.task_id == task.external_task_id,
TI.execution_date ==
pendulum.parse(task.execution_date))

for tii in external_tis:
if not dag_bag:
dag_bag = DagBag()
external_dag = dag_bag.get_dag(tii.dag_id)
if not external_dag:
raise AirflowException("Could not find dag {}".format(tii.dag_id))
downstream = external_dag.sub_dag(
task_regex=r"^{}$".format(tii.task_id),
include_upstream=False,
include_downstream=True
)
tis = tis.union(downstream.clear(start_date=tii.execution_date,
end_date=tii.execution_date,
only_failed=only_failed,
only_running=only_running,
confirm_prompt=confirm_prompt,
include_subdags=include_subdags,
include_parentdag=False,
dag_run_state=dag_run_state,
get_tis=True,
session=session,
recursion_depth=recursion_depth + 1,
max_recursion_depth=max_recursion_depth,
dag_bag=dag_bag))
if visited_external_tis is None:
visited_external_tis = set()
ti_key = ti.key.primary
if ti_key not in visited_external_tis:
# Only clear this ExternalTaskMarker if it's not already visited by the
# recursive calls to dag.clear().
task: ExternalTaskMarker = cast(ExternalTaskMarker,
copy.copy(self.get_task(ti.task_id)))
ti.task = task

if recursion_depth == 0:
# Maximum recursion depth allowed is the recursion_depth of the first
# ExternalTaskMarker in the tasks to be cleared.
max_recursion_depth = task.recursion_depth

if recursion_depth + 1 > max_recursion_depth:
# Prevent cycles or accidents.
raise AirflowException("Maximum recursion depth {} reached for {} {}. "
"Attempted to clear too many tasks "
"or there may be a cyclic dependency."
.format(max_recursion_depth,
ExternalTaskMarker.__name__, ti.task_id))
ti.render_templates()
external_tis = session.query(TI).filter(TI.dag_id == task.external_dag_id,
TI.task_id == task.external_task_id,
TI.execution_date ==
pendulum.parse(task.execution_date))

for tii in external_tis:
if not dag_bag:
dag_bag = DagBag(read_dags_from_db=True)
external_dag = dag_bag.get_dag(tii.dag_id)
if not external_dag:
raise AirflowException("Could not find dag {}".format(tii.dag_id))
downstream = external_dag.sub_dag(
task_regex=r"^{}$".format(tii.task_id),
include_upstream=False,
include_downstream=True
)
tis = tis.union(downstream.clear(start_date=tii.execution_date,
end_date=tii.execution_date,
only_failed=only_failed,
only_running=only_running,
confirm_prompt=confirm_prompt,
include_subdags=include_subdags,
include_parentdag=False,
dag_run_state=dag_run_state,
get_tis=True,
session=session,
recursion_depth=recursion_depth + 1,
max_recursion_depth=max_recursion_depth,
dag_bag=dag_bag,
visited_external_tis=visited_external_tis))
visited_external_tis.add(ti_key)

if get_tis:
return tis
Expand Down Expand Up @@ -1375,12 +1389,15 @@ def partial_subset(
based on a regex that should match one or many tasks, and includes
upstream and downstream neighbours based on the flag passed.
"""
# deep-copying self.task_dict takes a long time, and we don't want all
# deep-copying self.task_dict and self._task_group takes a long time, and we don't want all
# the tasks anyway, so we copy the tasks manually later
task_dict = self.task_dict
task_group = self._task_group
self.task_dict = {}
self._task_group = None
dag = copy.deepcopy(self)
self.task_dict = task_dict
self._task_group = task_group

regex_match = [
t for t in self.tasks if re.findall(task_regex, t.task_id)]
Expand All @@ -1396,24 +1413,30 @@ def partial_subset(
dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag})
for t in regex_match + also_include}

# Remove tasks not included in the subdag from task_group
def remove_excluded(group):
for child in list(group.children.values()):
def filter_task_group(group, parent_group):
"""
Exclude tasks not included in the subdag from the given TaskGroup.
"""
copied = copy.copy(group)
copied.used_group_ids = set(copied.used_group_ids)
copied._parent_group = parent_group

copied.children = {}

for child in group.children.values():
if isinstance(child, BaseOperator):
if child.task_id not in dag.task_dict:
group.children.pop(child.task_id)
else:
# The tasks in the subdag are a copy of tasks in the original dag
# so update the reference in the TaskGroups too.
group.children[child.task_id] = dag.task_dict[child.task_id]
if child.task_id in dag.task_dict:
copied.children[child.task_id] = dag.task_dict[child.task_id]
else:
remove_excluded(child)
filtered_child = filter_task_group(child, copied)

# Only include this child TaskGroup if it is non-empty.
if filtered_child.children:
copied.children[child.group_id] = filtered_child

# Remove this TaskGroup if it doesn't contain any tasks in this subdag
if not child.children:
group.children.pop(child.group_id)
return copied

remove_excluded(dag.task_group)
dag._task_group = filter_task_group(self._task_group, None)

# Removing upstream/downstream references to tasks and TaskGroups that did not make
# the cut.
Expand Down
54 changes: 54 additions & 0 deletions tests/sensors/test_external_task_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,3 +668,57 @@ def test_clear_multiple_external_task_marker(dag_bag_multiple):
# That has since been fixed. It should take no more than a few seconds to call
# dag.clear() here.
assert agg_dag.clear(start_date=execution_date, end_date=execution_date, dag_bag=dag_bag_multiple) == 51


@pytest.fixture
def dag_bag_head_tail():
"""
Create a DagBag containing one DAG, with task "head" depending on task "tail" of the
previous execution_date.
20200501 20200502 20200510
+------+ +------+ +------+
| head | -->head | --> -->head |
| | | / | | | / / | | |
| v | / | v | / / | v |
| body | / | body | / ... / | body |
| | |/ | | |/ / | | |
| v / | v / / | v |
| tail/| | tail/| / | tail |
+------+ +------+ +------+
"""
dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
with DAG("head_tail", start_date=DEFAULT_DATE, schedule_interval="@daily") as dag:
head = ExternalTaskSensor(task_id='head',
external_dag_id=dag.dag_id,
external_task_id="tail",
execution_delta=timedelta(days=1),
mode="reschedule")
body = DummyOperator(task_id="body")
tail = ExternalTaskMarker(task_id="tail",
external_dag_id=dag.dag_id,
external_task_id=head.task_id,
execution_date="{{ tomorrow_ds_nodash }}")
head >> body >> tail

dag_bag.bag_dag(dag=dag, root_dag=dag)

yield dag_bag


def test_clear_overlapping_external_task_marker(dag_bag_head_tail):
dag = dag_bag_head_tail.get_dag("head_tail")

# Mark first head task success.
first = TaskInstance(task=dag.get_task("head"), execution_date=DEFAULT_DATE)
first.run(mark_success=True)

for delta in range(10):
execution_date = DEFAULT_DATE + timedelta(days=delta)
run_tasks(dag_bag_head_tail, execution_date=execution_date)

# The next two lines are doing the same thing. Clearing the first "head" with "Future"
# selected is the same as not selecting "Future". They should take similar amount of
# time too because dag.clear() uses visited_external_tis to keep track of visited ExternalTaskMarker.
assert dag.clear(start_date=DEFAULT_DATE, dag_bag=dag_bag_head_tail) == 30
assert dag.clear(start_date=DEFAULT_DATE, end_date=execution_date, dag_bag=dag_bag_head_tail) == 30

0 comments on commit 4f2e0cf

Please sign in to comment.