Commit 8e2d7cc
fixup! Swap CeleryExecutor over to use TaskSDK for execution.
ashb committed Jan 30, 2025
1 parent 459c746 commit 8e2d7cc
Showing 19 changed files with 256 additions and 117 deletions.
45 changes: 34 additions & 11 deletions airflow/executors/base_executor.py
@@ -59,7 +59,7 @@
# Command to execute - list of strings
# the first element is always "airflow".
# It should be result of TaskInstance.generate_command method.
CommandType = list[str]
CommandType = Sequence[str]

# Task that is queued. It contains all the information that is
# needed to run the task.
@@ -223,7 +223,12 @@ def has_task(self, task_instance: TaskInstance) -> bool:
:param task_instance: TaskInstance
:return: True if the task is known to this executor
"""
return task_instance.key in self.queued_tasks or task_instance.key in self.running
return (
task_instance.id in self.queued_tasks
or task_instance.id in self.running
or task_instance.key in self.queued_tasks
or task_instance.key in self.running
)

def sync(self) -> None:
"""
@@ -319,6 +324,20 @@ def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, QueuedTa
:return: List of tuples from the queued_tasks according to the priority.
"""
from airflow.executors import workloads

if not self.queued_tasks:
return []

kind = next(iter(self.queued_tasks.values()))
if isinstance(kind, workloads.BaseWorkload):
# V3 + new executor that supports workloads
return sorted(
self.queued_tasks.items(),
key=lambda x: x[1].ti.priority_weight,
reverse=True,
)

return sorted(
self.queued_tasks.items(),
key=lambda x: x[1][1],
@@ -332,12 +351,12 @@ def trigger_tasks(self, open_slots: int) -> None:
:param open_slots: Number of open slots
"""
span = Trace.get_current_span()
sorted_queue = self.order_queued_tasks_by_priority()
task_tuples = []
workloads = []

for _ in range(min((open_slots, len(self.queued_tasks)))):
key, (command, _, queue, ti) = sorted_queue.pop(0)
key, item = sorted_queue.pop(0)

# If a task makes it here but is still understood by the executor
# to be running, it generally means that the task has been killed
@@ -375,15 +394,19 @@ def trigger_tasks(self, open_slots: int) -> None:
else:
if key in self.attempts:
del self.attempts[key]
task_tuples.append((key, command, queue, ti.executor_config))
if span.is_recording():
span.add_event(
name="task to trigger",
attributes={"command": str(command), "conf": str(ti.executor_config)},
)
# TODO: TaskSDK: Compat, remove when KubeExecutor is fully moved over to TaskSDK too.
# TODO: TaskSDK: We need to set minimum version requirements on executors with Airflow 3.
# How/where do we do that? Executor loader?
if hasattr(self, "_process_workloads"):
workloads.append(item)
else:
(command, _, queue, ti) = item
task_tuples.append((key, command, queue, getattr(ti, "executor_config", None)))

if task_tuples:
self._process_tasks(task_tuples)
elif workloads:
self._process_workloads(workloads) # type: ignore[attr-defined]

@add_span
def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
@@ -625,7 +648,7 @@ def slots_occupied(self):
return len(self.running) + len(self.queued_tasks)

@staticmethod
def validate_airflow_tasks_run_command(command: list[str]) -> tuple[str | None, str | None]:
def validate_airflow_tasks_run_command(command: Sequence[str]) -> tuple[str | None, str | None]:
"""
Check if the command to execute is an airflow command.
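The reworked `trigger_tasks` hunk above dispatches on the executor subclass itself: if the concrete executor defines `_process_workloads`, queued items are treated as Task SDK workloads and handed over untouched; otherwise they are unpacked as the legacy `(command, priority, queue, ti)` tuples. A minimal, self-contained sketch of that compat dispatch (toy classes, not the real Airflow API):

```python
# Toy illustration of the hasattr-based compat check added to BaseExecutor.trigger_tasks.
# Only the method names mirror the diff; the classes and data below are stand-ins.

class LegacyExecutor:
    def _process_tasks(self, task_tuples):
        for key, command, queue, executor_config in task_tuples:
            print(f"legacy path: {key} -> {command!r} on queue {queue!r}")


class WorkloadExecutor(LegacyExecutor):
    def _process_workloads(self, workloads):
        for workload in workloads:
            print(f"workload path: {workload!r}")


def trigger(executor, sorted_queue):
    """Route queued items the same way the reworked trigger_tasks() does."""
    task_tuples, workloads = [], []
    for key, item in sorted_queue:
        if hasattr(executor, "_process_workloads"):
            # New-style executor: keep the workload object as-is.
            workloads.append(item)
        else:
            # Old-style executor: unpack the legacy queued tuple.
            command, _, queue, ti = item
            task_tuples.append((key, command, queue, getattr(ti, "executor_config", None)))
    if task_tuples:
        executor._process_tasks(task_tuples)
    elif workloads:
        executor._process_workloads(workloads)


if __name__ == "__main__":
    trigger(WorkloadExecutor(), [("ti-key-1", {"kind": "ExecuteTask"})])
    trigger(LegacyExecutor(), [("ti-key-2", (["airflow", "tasks", "run"], 1, "default", None))])
```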
4 changes: 2 additions & 2 deletions airflow/executors/workloads.py
@@ -34,7 +34,7 @@
]


class BaseActivity(BaseModel):
class BaseWorkload(BaseModel):
token: str
"""The identity token for this workload"""

@@ -75,7 +75,7 @@ def key(self) -> TaskInstanceKey:
)


class ExecuteTask(BaseActivity):
class ExecuteTask(BaseWorkload):
"""Execute the given Task."""

ti: TaskInstance
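With `BaseActivity` renamed to `BaseWorkload`, the base class matches the `isinstance(kind, workloads.BaseWorkload)` check that `order_queued_tasks_by_priority` now performs in `base_executor.py`. A hypothetical sketch of the shape involved (field types are simplified stand-ins; the real models live in `airflow/executors/workloads.py`):

```python
# Hypothetical sketch only: mirrors the class shapes visible in the diff.
from pydantic import BaseModel


class TaskInstance(BaseModel):
    # Stand-in for the Task SDK TaskInstance model; the real fields differ.
    id: str
    queue: str = "default"
    priority_weight: int = 1


class BaseWorkload(BaseModel):
    token: str
    """The identity token for this workload."""


class ExecuteTask(BaseWorkload):
    """Execute the given Task."""

    ti: TaskInstance


workload = ExecuteTask(token="...", ti=TaskInstance(id="abc123"))
assert isinstance(workload, BaseWorkload)  # what the new priority-ordering check relies on
```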
4 changes: 3 additions & 1 deletion airflow/jobs/scheduler_job_runner.py
@@ -836,7 +836,6 @@ def process_executor_events(
)
if info is not None:
msg += " Extra info: %s" % info # noqa: RUF100, UP031, flynt
cls.logger().error(msg)
session.add(Log(event="state mismatch", extra=msg, task_instance=ti.key))

# Get task from the Serialized DAG
@@ -849,6 +848,9 @@
continue
ti.task = task
if task.on_retry_callback or task.on_failure_callback:
# Only log the error/extra info here, since the `ti.handle_failure()` path will log it
# too, which would lead to double logging
cls.logger().error(msg)
request = TaskCallbackRequest(
full_filepath=ti.dag_model.fileloc,
ti=ti,
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -335,7 +335,7 @@
"celery": {
"deps": [
"apache-airflow>=2.9.0",
"celery[redis]>=5.3.0,<6,!=5.3.3,!=5.3.2",
"celery[redis]>=5.4.0,<6",
"flower>=1.0.0",
"google-re2>=1.0"
],
8 changes: 4 additions & 4 deletions providers/celery/README.rst
@@ -51,14 +51,14 @@ The package supports the following python versions: 3.9,3.10,3.11,3.12
Requirements
------------

================== ==============================
================== ==================
PIP package Version required
================== ==============================
================== ==================
``apache-airflow`` ``>=2.9.0``
``celery[redis]`` ``>=5.3.0,!=5.3.2,!=5.3.3,<6``
``celery[redis]`` ``>=5.4.0,<6``
``flower`` ``>=1.0.0``
``google-re2`` ``>=1.0``
================== ==============================
================== ==================

Cross provider package dependencies
-----------------------------------
4 changes: 1 addition & 3 deletions providers/celery/pyproject.toml
@@ -59,9 +59,7 @@ dependencies = [
# Celery is known to introduce problems when upgraded to a MAJOR version. Airflow Core
# uses Celery for CeleryExecutor, and we also know that Celery follows SemVer
# (https://docs.celeryq.dev/en/stable/contributing.html?highlight=semver#versions).
# Make sure that the limit here is synchronized with [celery] extra in the airflow core
# The 5.3.3/5.3.2 limit comes from https://github.com/celery/celery/issues/8470
"celery[redis]>=5.3.0,<6,!=5.3.3,!=5.3.2",
"celery[redis]>=5.4.0,<6",
"flower>=1.0.0",
"google-re2>=1.0",
]
@@ -154,6 +154,11 @@ def worker(args):
# This needs to be imported locally to not trigger Providers Manager initialization
from airflow.providers.celery.executors.celery_executor import app as celery_app

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk.log import configure_logging

configure_logging(output=sys.stdout.buffer)

# Disable connection pool so that celery worker does not hold an unnecessary db connection
settings.reconfigure_orm(disable_connection_pool=True)
if not settings.validate_session():
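In the `airflow celery worker` entrypoint, the added block only runs on Airflow 3 and points the Task SDK's `configure_logging` helper at the worker's raw stdout buffer before the ORM is reconfigured. A condensed sketch of where the guard sits (function body abbreviated; `AIRFLOW_V_3_0_PLUS` is assumed to come from the provider's `version_compat` module, as it does in `celery_executor.py`):

```python
import sys

from airflow import settings
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS


def worker(args):
    # Imported locally so it does not trigger Providers Manager initialization.
    from airflow.providers.celery.executors.celery_executor import app as celery_app  # noqa: F401

    if AIRFLOW_V_3_0_PLUS:
        # Route worker logging through the Task SDK; the call takes the binary stdout buffer.
        from airflow.sdk.log import configure_logging

        configure_logging(output=sys.stdout.buffer)

    # Disable connection pool so that the celery worker does not hold an unnecessary db connection.
    settings.reconfigure_orm(disable_connection_pool=True)
```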
@@ -33,7 +33,7 @@
from collections.abc import Sequence
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any

from deprecated import deprecated

@@ -53,7 +53,7 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.celery.version_compat import AIRFLOW_V_2_8_PLUS
from airflow.providers.celery.version_compat import AIRFLOW_V_2_8_PLUS, AIRFLOW_V_3_0_PLUS
from airflow.stats import Stats
from airflow.utils.state import TaskInstanceState
from celery import states as celery_states
@@ -67,14 +67,13 @@
if TYPE_CHECKING:
import argparse

from airflow.executors.base_executor import CommandType, TaskTuple
from sqlalchemy.orm import Session

from airflow.executors import workloads
from airflow.executors.base_executor import TaskTuple
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from celery import Task

# Task instance that is sent over Celery queues
# TaskInstanceKey, Command, queue_name, CallableTask
TaskInstanceInCelery = tuple[TaskInstanceKey, CommandType, Optional[str], Task]
from airflow.providers.celery.executors.celery_executor_utils import TaskInstanceInCelery


# PEP562
@@ -228,6 +227,11 @@ class CeleryExecutor(BaseExecutor):
supports_ad_hoc_ti_run: bool = True
supports_sentry: bool = True

if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
# In the v3 path, we store workloads, not commands as strings.
# TODO: TaskSDK: move this type change into BaseExecutor
queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment]

def __init__(self):
super().__init__()

@@ -256,10 +260,22 @@ def _num_tasks_per_send_process(self, to_send_count: int) -> int:
return max(1, math.ceil(to_send_count / self._sync_parallelism))

def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
# Airflow V2 version
from airflow.providers.celery.executors.celery_executor_utils import execute_command

task_tuples_to_send = [task_tuple[:3] + (execute_command,) for task_tuple in task_tuples]
first_task = next(t[3] for t in task_tuples_to_send)

self._send_tasks(task_tuples_to_send)

def _process_workloads(self, workloads: list[workloads.All]) -> None:
# Airflow V3 version
from airflow.providers.celery.executors.celery_executor_utils import execute_workload

tasks = [(workload.ti.key, workload, workload.ti.queue, execute_workload) for workload in workloads]
self._send_tasks(tasks)

def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]):
first_task = next(t[-1] for t in task_tuples_to_send)

# Celery state queries will get stuck if we do not use the same backend
# for all tasks.
@@ -299,7 +315,7 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
# which point we don't need the ID anymore anyway
self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id)

def _send_tasks_to_celery(self, task_tuples_to_send: list[TaskInstanceInCelery]):
def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]):
from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor

if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
@@ -359,7 +375,7 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None
self.success(key, info)
elif state in (celery_states.FAILURE, celery_states.REVOKED):
self.fail(key, info)
elif state in (celery_states.STARTED, celery_states.PENDING):
elif state in (celery_states.STARTED, celery_states.PENDING, celery_states.RETRY):
pass
else:
self.log.info("Unexpected state for %s: %s", key, state)
@@ -416,6 +432,10 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
for celery_task_id, (state, info) in states_by_celery_task_id.items():
result, ti = celery_tasks[celery_task_id]
result.backend = cached_celery_backend
if isinstance(result.result, BaseException):
e = result.result
# Log the exception we got from the remote end
self.log.warning("Task %s failed with error", ti.key, exc_info=e)

# Set the correct elements of the state dicts, then update this
# like we just queried it.
@@ -475,6 +495,10 @@ def get_cli_commands() -> list[GroupCommand]:
),
]

def queue_workload(self, workload: workloads.ExecuteTask, session: Session | None) -> None:
ti = workload.ti
self.queued_tasks[ti.key] = workload


def _get_parser() -> argparse.ArgumentParser:
"""
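Pieced together, the Airflow 3 path through `CeleryExecutor` is: the scheduler hands a workload to `queue_workload`, `BaseExecutor.trigger_tasks` routes it to `_process_workloads` (because that method now exists), and both the v2 and v3 paths converge on a shared `_send_tasks` helper that picks the Celery callable from the last element of each tuple. A condensed, non-authoritative sketch of that flow, assembled only from the hunks above:

```python
# Simplified sketch of CeleryExecutor's two submission paths after this change.
# Signatures and bodies are abbreviated; treat as illustration, not the real class.


class CeleryExecutorSketch:
    def __init__(self) -> None:
        self.queued_tasks = {}

    def queue_workload(self, workload, session=None) -> None:
        # Scheduler side (Airflow 3): store the whole workload, keyed by the TI key.
        self.queued_tasks[workload.ti.key] = workload

    def _process_workloads(self, workloads) -> None:
        # Airflow 3 path, reached from BaseExecutor.trigger_tasks because this method exists.
        from airflow.providers.celery.executors.celery_executor_utils import execute_workload

        self._send_tasks([(w.ti.key, w, w.ti.queue, execute_workload) for w in workloads])

    def _process_tasks(self, task_tuples) -> None:
        # Legacy Airflow 2 path, still wrapping string commands with execute_command.
        from airflow.providers.celery.executors.celery_executor_utils import execute_command

        self._send_tasks([t[:3] + (execute_command,) for t in task_tuples])

    def _send_tasks(self, task_tuples_to_send) -> None:
        # The Celery task callable is always the last tuple element for both paths,
        # hence the diff's switch from t[3] to t[-1] when choosing the result backend.
        first_task = next(t[-1] for t in task_tuples_to_send)
        # ... pick the backend from first_task and fan out via _send_tasks_to_celery(),
        # exactly as in the hunks above.
```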