From 15aec32f061055a78ec532336f15f9203dda115a Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Thu, 30 Jan 2025 13:33:45 +0000
Subject: [PATCH] Swap CeleryExecutor over to use TaskSDK for execution.

Some points of note about this change:

- Logging is changed in Celery, but only for Airflow 3

  Celery does its own "capture stdout" logging, which conflicts with the one
  we do in the TaskSDK, so we disable that -- but only on Airflow 3, so that
  nothing changes for Airflow 2.

- Simplify task SDK logging redirection

  As part of discovering that Celery captures stdout/stderr itself (and
  before disabling that), I found a simpler way to re-open
  stdin/stdout/stderr so that the implementation needs little or no
  special-casing.

- Make JSON task logs more readable by giving them a consistent/useful order

  We re-order (by re-creating) the event_dict so that timestamp, level, and
  then event are always the first items in the dict.

- Make the CeleryExecutor understand the concept of "workloads" instead of a
  command tuple.

  This change isn't done in the best way, but until the Kube executor is
  swapped over (and possibly the other in-tree executors, such as ECS) we
  need to support both styles concurrently. The change is done in such a way
  that the provider still works with Airflow v2, if it's running on that
  version.

- Upgrade Celery

  This turned out to not be 100% necessary, but it does fix some deprecation
  warnings when running on Python 3.12.

- Ensure that the forked process in TaskSDK _never ever_ exits

  This shouldn't normally be possible, but when the setup step of
  `_fork_main` died it didn't call `os._exit()`, the exception was caught
  further up the stack, and the process stayed alive because it never closed
  its sockets properly. We put an extra safety try/except block in place to
  catch this case.

I have not yet included a newsfragment for changing the executor interface,
as the old style is _currently_ still supported.
---
 airflow/executors/base_executor.py             |  45 +++++--
 airflow/executors/workloads.py                 |   4 +-
 airflow/jobs/scheduler_job_runner.py           |   4 +-
 generated/provider_dependencies.json           |   2 +-
 providers/celery/README.rst                    |   8 +-
 providers/celery/pyproject.toml                |   4 +-
 .../providers/celery/cli/celery_command.py     |   5 +
 .../celery/executors/celery_executor.py        |  48 +++++--
 .../celery/executors/celery_executor_utils.py  |  88 +++++++++---
 .../executors/celery_kubernetes_executor.py    |   4 +-
 .../celery/executors/default_celery.py         |  15 ++-
 .../providers/celery/get_provider_info.py      |   2 +-
 .../celery/executors/test_celery_executor.py   |  18 ++-
 .../aws/executors/batch/batch_executor.py      |   4 +-
 .../amazon/aws/executors/ecs/utils.py          |   3 +-
 .../executors/kubernetes_executor_utils.py     |   2 +-
 .../executors/batch/test_batch_executor.py     |  16 +--
 .../airflow/sdk/execution_time/supervisor.py   |  52 ++++----
 task_sdk/src/airflow/sdk/log.py                |  53 +++++---
 .../executors/test_celery_executor.py          | 126 ++++++++----------
 20 files changed, 313 insertions(+), 190 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index a0f48c74b13560..765623d8c94279 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -59,7 +59,7 @@
     # Command to execute - list of strings
     # the first element is always "airflow".
     # It should be result of TaskInstance.generate_command method.
-    CommandType = list[str]
+    CommandType = Sequence[str]
 
     # Task that is queued. It contains all the information that is
     # needed to run the task.
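A quick illustration of why the CommandType alias above is widened: Sequence[str] lets callers hand the executor either a list or a tuple, and signals that the receiver should treat the command as read-only. A minimal sketch (the describe function is hypothetical, not Airflow code):

    from collections.abc import Sequence

    CommandType = Sequence[str]  # was list[str]

    def describe(command: CommandType) -> str:
        # Read-only access only -- Sequence has no .append(), which is the
        # point: executors receive the command but should not mutate it.
        return " ".join(command)

    describe(["airflow", "tasks", "run"])  # a list still type-checks
    describe(("airflow", "tasks", "run"))  # ...and now a tuple does too

Call sites elsewhere in this patch that really need a list (the Batch and Kubernetes executors) now convert explicitly with list(command).
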
@@ -223,7 +223,12 @@ def has_task(self, task_instance: TaskInstance) -> bool:
         :param task_instance: TaskInstance
         :return: True if the task is known to this executor
         """
-        return task_instance.key in self.queued_tasks or task_instance.key in self.running
+        return (
+            task_instance.id in self.queued_tasks
+            or task_instance.id in self.running
+            or task_instance.key in self.queued_tasks
+            or task_instance.key in self.running
+        )
 
     def sync(self) -> None:
         """
@@ -319,6 +324,20 @@ def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, QueuedTa
         :return: List of tuples from the queued_tasks according to the priority.
         """
+        from airflow.executors import workloads
+
+        if not self.queued_tasks:
+            return []
+
+        kind = next(iter(self.queued_tasks.values()))
+        if isinstance(kind, workloads.BaseWorkload):
+            # V3 + new executor that supports workloads
+            return sorted(
+                self.queued_tasks.items(),
+                key=lambda x: x[1].ti.priority_weight,
+                reverse=True,
+            )
+
         return sorted(
             self.queued_tasks.items(),
             key=lambda x: x[1][1],
@@ -332,12 +351,12 @@ def trigger_tasks(self, open_slots: int) -> None:
 
         :param open_slots: Number of open slots
         """
-        span = Trace.get_current_span()
         sorted_queue = self.order_queued_tasks_by_priority()
         task_tuples = []
+        workloads = []
 
         for _ in range(min((open_slots, len(self.queued_tasks)))):
-            key, (command, _, queue, ti) = sorted_queue.pop(0)
+            key, item = sorted_queue.pop(0)
 
             # If a task makes it here but is still understood by the executor
             # to be running, it generally means that the task has been killed
@@ -375,15 +394,19 @@ def trigger_tasks(self, open_slots: int) -> None:
             else:
                 if key in self.attempts:
                     del self.attempts[key]
-                task_tuples.append((key, command, queue, ti.executor_config))
-                if span.is_recording():
-                    span.add_event(
-                        name="task to trigger",
-                        attributes={"command": str(command), "conf": str(ti.executor_config)},
-                    )
+                # TODO: TaskSDK: Compat, remove when KubeExecutor is fully moved over to TaskSDK too.
+                # TODO: TaskSDK: We need to set minimum version requirements on executors with Airflow 3.
+                # How/where do we do that? Executor loader?
+                if hasattr(self, "_process_workloads"):
+                    workloads.append(item)
+                else:
+                    (command, _, queue, ti) = item
+                    task_tuples.append((key, command, queue, getattr(ti, "executor_config", None)))
 
         if task_tuples:
             self._process_tasks(task_tuples)
+        elif workloads:
+            self._process_workloads(workloads)  # type: ignore[attr-defined]
 
     @add_span
     def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
@@ -625,7 +648,7 @@ def slots_occupied(self):
         return len(self.running) + len(self.queued_tasks)
 
     @staticmethod
-    def validate_airflow_tasks_run_command(command: list[str]) -> tuple[str | None, str | None]:
+    def validate_airflow_tasks_run_command(command: Sequence[str]) -> tuple[str | None, str | None]:
         """
         Check if the command to execute is airflow command.
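To make the dual code path above concrete, here is a minimal, self-contained sketch (toy classes, not Airflow code) of the hasattr-based dispatch that trigger_tasks now performs: executors that define _process_workloads receive whole workload objects, while everything else keeps receiving the old (command, priority, queue, ti) tuples:

    class LegacyExecutor:
        def _process_tasks(self, task_tuples):
            for key, command, queue, executor_config in task_tuples:
                print("v2 path:", key, command)

    class WorkloadExecutor:
        def _process_workloads(self, workloads):
            for workload in workloads:
                print("v3 path:", workload)

    def trigger(executor, items):
        # Mirrors the hasattr() probe in BaseExecutor.trigger_tasks above.
        if hasattr(executor, "_process_workloads"):
            executor._process_workloads(items)
        else:
            executor._process_tasks(items)

    trigger(LegacyExecutor(), [("key1", ["airflow", "tasks", "run"], None, None)])
    trigger(WorkloadExecutor(), ["workload-object-goes-here"])
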
diff --git a/airflow/executors/workloads.py b/airflow/executors/workloads.py index 4c3eebe6811b9b..f3288e2ae219cf 100644 --- a/airflow/executors/workloads.py +++ b/airflow/executors/workloads.py @@ -34,7 +34,7 @@ ] -class BaseActivity(BaseModel): +class BaseWorkload(BaseModel): token: str """The identity token for this workload""" @@ -75,7 +75,7 @@ def key(self) -> TaskInstanceKey: ) -class ExecuteTask(BaseActivity): +class ExecuteTask(BaseWorkload): """Execute the given Task.""" ti: TaskInstance diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 92b7c2b0010ed2..b2c8c8220b9934 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -836,7 +836,6 @@ def process_executor_events( ) if info is not None: msg += " Extra info: %s" % info # noqa: RUF100, UP031, flynt - cls.logger().error(msg) session.add(Log(event="state mismatch", extra=msg, task_instance=ti.key)) # Get task from the Serialized DAG @@ -849,6 +848,9 @@ def process_executor_events( continue ti.task = task if task.on_retry_callback or task.on_failure_callback: + # Only log the error/extra info here, since the `ti.handle_failure()` path will log it + # too, which would lead to double logging + cls.logger().error(msg) request = TaskCallbackRequest( full_filepath=ti.dag_model.fileloc, ti=ti, diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 681b18a0d2fef3..9114068e4303f2 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -336,7 +336,7 @@ "celery": { "deps": [ "apache-airflow>=2.9.0", - "celery[redis]>=5.3.0,<6,!=5.3.3,!=5.3.2", + "celery[redis]>=5.4.0,<6", "flower>=1.0.0", "google-re2>=1.0" ], diff --git a/providers/celery/README.rst b/providers/celery/README.rst index f2177d13b60132..001620b7289162 100644 --- a/providers/celery/README.rst +++ b/providers/celery/README.rst @@ -51,14 +51,14 @@ The package supports the following python versions: 3.9,3.10,3.11,3.12 Requirements ------------ -================== ============================== +================== ================== PIP package Version required -================== ============================== +================== ================== ``apache-airflow`` ``>=2.9.0`` -``celery[redis]`` ``>=5.3.0,!=5.3.2,!=5.3.3,<6`` +``celery[redis]`` ``>=5.4.0,<6`` ``flower`` ``>=1.0.0`` ``google-re2`` ``>=1.0`` -================== ============================== +================== ================== Cross provider package dependencies ----------------------------------- diff --git a/providers/celery/pyproject.toml b/providers/celery/pyproject.toml index 011a3812c4803a..959a219c2ebc34 100644 --- a/providers/celery/pyproject.toml +++ b/providers/celery/pyproject.toml @@ -59,9 +59,7 @@ dependencies = [ # The Celery is known to introduce problems when upgraded to a MAJOR version. Airflow Core # Uses Celery for CeleryExecutor, and we also know that Kubernetes Python client follows SemVer # (https://docs.celeryq.dev/en/stable/contributing.html?highlight=semver#versions). 
- # Make sure that the limit here is synchronized with [celery] extra in the airflow core - # The 5.3.3/5.3.2 limit comes from https://github.com/celery/celery/issues/8470 - "celery[redis]>=5.3.0,<6,!=5.3.3,!=5.3.2", + "celery[redis]>=5.4.0,<6", "flower>=1.0.0", "google-re2>=1.0", ] diff --git a/providers/celery/src/airflow/providers/celery/cli/celery_command.py b/providers/celery/src/airflow/providers/celery/cli/celery_command.py index aaff91bff226cb..aa0a0ec2ebe577 100644 --- a/providers/celery/src/airflow/providers/celery/cli/celery_command.py +++ b/providers/celery/src/airflow/providers/celery/cli/celery_command.py @@ -154,6 +154,11 @@ def worker(args): # This needs to be imported locally to not trigger Providers Manager initialization from airflow.providers.celery.executors.celery_executor import app as celery_app + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.log import configure_logging + + configure_logging(output=sys.stdout.buffer) + # Disable connection pool so that celery worker does not hold an unnecessary db connection settings.reconfigure_orm(disable_connection_pool=True) if not settings.validate_session(): diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index d9121dcd7ab32a..970921c9081d1e 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -33,7 +33,7 @@ from collections.abc import Sequence from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from deprecated import deprecated @@ -53,7 +53,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowTaskTimeout from airflow.executors.base_executor import BaseExecutor -from airflow.providers.celery.version_compat import AIRFLOW_V_2_8_PLUS +from airflow.providers.celery.version_compat import AIRFLOW_V_2_8_PLUS, AIRFLOW_V_3_0_PLUS from airflow.stats import Stats from airflow.utils.state import TaskInstanceState from celery import states as celery_states @@ -67,14 +67,13 @@ if TYPE_CHECKING: import argparse - from airflow.executors.base_executor import CommandType, TaskTuple + from sqlalchemy.orm import Session + + from airflow.executors import workloads + from airflow.executors.base_executor import TaskTuple from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey - from celery import Task - - # Task instance that is sent over Celery queues - # TaskInstanceKey, Command, queue_name, CallableTask - TaskInstanceInCelery = tuple[TaskInstanceKey, CommandType, Optional[str], Task] + from airflow.providers.celery.executors.celery_executor_utils import TaskInstanceInCelery # PEP562 @@ -228,6 +227,11 @@ class CeleryExecutor(BaseExecutor): supports_ad_hoc_ti_run: bool = True supports_sentry: bool = True + if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: + # In the v3 path, we store workloads, not commands as strings. 
+ # TODO: TaskSDK: move this type change into BaseExecutor + queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] + def __init__(self): super().__init__() @@ -256,10 +260,22 @@ def _num_tasks_per_send_process(self, to_send_count: int) -> int: return max(1, math.ceil(to_send_count / self._sync_parallelism)) def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: + # Airflow V2 version from airflow.providers.celery.executors.celery_executor_utils import execute_command task_tuples_to_send = [task_tuple[:3] + (execute_command,) for task_tuple in task_tuples] - first_task = next(t[3] for t in task_tuples_to_send) + + self._send_tasks(task_tuples_to_send) + + def _process_workloads(self, workloads: list[workloads.All]) -> None: + # Airflow V3 version + from airflow.providers.celery.executors.celery_executor_utils import execute_workload + + tasks = [(workload.ti.key, workload, workload.ti.queue, execute_workload) for workload in workloads] + self._send_tasks(tasks) + + def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): + first_task = next(t[-1] for t in task_tuples_to_send) # Celery state queries will stuck if we do not use one same backend # for all tasks. @@ -280,7 +296,7 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: "[Try %s of %s] Task Timeout Error for Task: (%s).", self.task_publish_retries[key] + 1, self.task_publish_max_retries, - key, + tuple(key), ) self.task_publish_retries[key] = retries + 1 continue @@ -299,7 +315,7 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: # which point we don't need the ID anymore anyway self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id) - def _send_tasks_to_celery(self, task_tuples_to_send: list[TaskInstanceInCelery]): + def _send_tasks_to_celery(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]): from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1: @@ -359,7 +375,7 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None self.success(key, info) elif state in (celery_states.FAILURE, celery_states.REVOKED): self.fail(key, info) - elif state in (celery_states.STARTED, celery_states.PENDING): + elif state in (celery_states.STARTED, celery_states.PENDING, celery_states.RETRY): pass else: self.log.info("Unexpected state for %s: %s", key, state) @@ -416,6 +432,10 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task for celery_task_id, (state, info) in states_by_celery_task_id.items(): result, ti = celery_tasks[celery_task_id] result.backend = cached_celery_backend + if isinstance(result.result, BaseException): + e = result.result + # Log the exception we got from the remote end + self.log.warning("Task %s failed with error", ti.key, exc_info=e) # Set the correct elements of the state dicts, then update this # like we just queried it. 
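The pattern the new _process_workloads/_send_tasks path relies on (see execute_workload and send_task_to_executor below) is: serialize the workload to a single JSON string argument on the scheduler side, then decode it against the union type on the worker. A self-contained sketch of that round trip with stand-in models -- the real field set of ExecuteTask (token, dag_rel_path, bundle_info, log_path, ...) lives in airflow/executors/workloads.py:

    from typing import Literal, Union

    from pydantic import BaseModel, TypeAdapter

    # Stand-ins for airflow.executors.workloads -- illustrative only.
    class ExecuteTask(BaseModel):
        kind: Literal["ExecuteTask"] = "ExecuteTask"
        token: str
        dag_rel_path: str

    class RunTrigger(BaseModel):
        kind: Literal["RunTrigger"] = "RunTrigger"
        token: str

    All = Union[ExecuteTask, RunTrigger]

    # Producer side (send_task_to_executor): one JSON string as the Celery task arg.
    payload = ExecuteTask(token="t", dag_rel_path="dags/demo.py").model_dump_json()

    # Worker side (execute_workload): decode against the union, then dispatch on type.
    workload = TypeAdapter(All).validate_json(payload)
    assert isinstance(workload, ExecuteTask)
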
@@ -475,6 +495,10 @@ def get_cli_commands() -> list[GroupCommand]: ), ] + def queue_workload(self, workload: workloads.ExecuteTask, session: Session | None) -> None: + ti = workload.ti + self.queued_tasks[ti.key] = workload + def _get_parser() -> argparse.ArgumentParser: """ diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 6d88d9f578d241..38b26ebd79b26d 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -31,7 +31,7 @@ import warnings from collections.abc import Mapping, MutableMapping from concurrent.futures import ProcessPoolExecutor -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from setproctitle import setproctitle from sqlalchemy import select @@ -40,8 +40,8 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout from airflow.executors.base_executor import BaseExecutor +from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS from airflow.stats import Stats -from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.providers_configuration_loader import providers_configuration_loaded @@ -51,14 +51,25 @@ from celery.backends.database import DatabaseBackend, Task as TaskDb, retry, session_cleanup from celery.signals import import_modules as celery_import_modules +try: + from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager +except ImportError: + from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager + log = logging.getLogger(__name__) if TYPE_CHECKING: + from airflow.executors import workloads from airflow.executors.base_executor import CommandType, EventBufferValueType from airflow.models.taskinstance import TaskInstanceKey + from airflow.typing_compat import TypeAlias from celery.result import AsyncResult - TaskInstanceInCelery = tuple[TaskInstanceKey, CommandType, Optional[str], Task] + # We can't use `if AIRFLOW_V_3_0_PLUS` conditions in type checks, so unfortunately we just have to define + # the type as the union of both kinds + TaskInstanceInCelery: TypeAlias = tuple[ + TaskInstanceKey, Union[workloads.All, CommandType], Optional[str], Task + ] OPERATION_TIMEOUT = conf.getfloat("celery", "operation_timeout") @@ -125,21 +136,54 @@ def on_celery_import_modules(*args, **kwargs): import kubernetes.client # noqa: F401 -@app.task -def execute_command(command_to_exec: CommandType) -> None: - """Execute command.""" - dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command_to_exec) +# Once Celery 5.5 is out of beta, we can pass `pydantic=True` to the decorator and it will handle the validation +# and deserialization for us +@app.task(name="execute_workload") +def execute_workload(input: str) -> None: + from pydantic import TypeAdapter + + from airflow.configuration import conf + from airflow.executors import workloads + from airflow.sdk.execution_time.supervisor import supervise + + decoder = TypeAdapter(workloads.All) + workload = decoder.validate_json(input) + celery_task_id = app.current_task.request.id - log.info("[%s] Executing command in 
Celery: %s", celery_task_id, command_to_exec) - with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): - try: - if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: - _execute_in_subprocess(command_to_exec, celery_task_id) - else: - _execute_in_fork(command_to_exec, celery_task_id) - except Exception: - Stats.incr("celery.execute_command.failure") - raise + + if not isinstance(workload, workloads.ExecuteTask): + raise ValueError(f"CeleryExecutor does not now how to handle {type(workload)}") + + log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload) + + supervise( + # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. + ti=workload.ti, # type: ignore[arg-type] + dag_rel_path=workload.dag_rel_path, + bundle_info=workload.bundle_info, + token=workload.token, + server=conf.get("workers", "execution_api_server_url", fallback="http://localhost:9091/execution/"), + log_path=workload.log_path, + ) + + +if not AIRFLOW_V_3_0_PLUS: + + @app.task + def execute_command(command_to_exec: CommandType) -> None: + """Execute command.""" + dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command_to_exec) + celery_task_id = app.current_task.request.id + log.info("[%s] Executing command in Celery: %s", celery_task_id, command_to_exec) + with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): + try: + if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: + _execute_in_subprocess(command_to_exec, celery_task_id) + else: + _execute_in_fork(command_to_exec, celery_task_id) + except Exception: + Stats.incr("celery.execute_command.failure") + raise def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None = None) -> None: @@ -213,15 +257,19 @@ def send_task_to_executor( task_tuple: TaskInstanceInCelery, ) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]: """Send task to executor.""" - key, command, queue, task_to_run = task_tuple + from airflow.executors import workloads + + key, args, queue, task_to_run = task_tuple + if isinstance(args, workloads.BaseWorkload): + args = (args.model_dump_json(),) try: with timeout(seconds=OPERATION_TIMEOUT): - result = task_to_run.apply_async(args=[command], queue=queue) + result = task_to_run.apply_async(args=args, queue=queue) except (Exception, AirflowTaskTimeout) as e: exception_traceback = f"Celery Task ID: {key}\n{traceback.format_exc()}" result = ExceptionWithTraceback(e, exception_traceback) - return key, command, result + return key, args, result def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | ExceptionWithTraceback, Any]: diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index 680bcfb3d603e2..9caf0a5866890e 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -97,9 +97,9 @@ def _task_event_logs(self, value): def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]: """Return queued tasks from celery and kubernetes executor.""" queued_tasks = self.celery_executor.queued_tasks.copy() - queued_tasks.update(self.kubernetes_executor.queued_tasks) + queued_tasks.update(self.kubernetes_executor.queued_tasks) # type: ignore[arg-type] - return queued_tasks + return queued_tasks # type: ignore[return-value] 
@queued_tasks.setter def queued_tasks(self, value) -> None: diff --git a/providers/celery/src/airflow/providers/celery/executors/default_celery.py b/providers/celery/src/airflow/providers/celery/executors/default_celery.py index 20c307a77b04fb..9fb4a7e3bbbb6f 100644 --- a/providers/celery/src/airflow/providers/celery/executors/default_celery.py +++ b/providers/celery/src/airflow/providers/celery/executors/default_celery.py @@ -27,6 +27,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS def _broker_supports_visibility_timeout(url): @@ -67,7 +68,7 @@ def _broker_supports_visibility_timeout(url): result_backend = conf.get_mandatory_value("celery", "RESULT_BACKEND") else: log.debug("Value for celery result_backend not found. Using sql_alchemy_conn with db+ prefix.") - result_backend = f'db+{conf.get("database", "SQL_ALCHEMY_CONN")}' + result_backend = f"db+{conf.get('database', 'SQL_ALCHEMY_CONN')}" extra_celery_config = conf.getjson("celery", "extra_celery_config", fallback={}) @@ -81,6 +82,9 @@ def _broker_supports_visibility_timeout(url): "task_track_started": conf.getboolean("celery", "task_track_started", fallback=True), "broker_url": broker_url, "broker_transport_options": broker_transport_options, + "broker_connection_retry_on_startup": conf.getboolean( + "celery", "broker_connection_retry_on_startup", fallback=True + ), "result_backend": result_backend, "database_engine_options": conf.getjson( "celery", "result_backend_sqlalchemy_engine_options", fallback={} @@ -90,6 +94,11 @@ def _broker_supports_visibility_timeout(url): **(extra_celery_config if isinstance(extra_celery_config, dict) else {}), } +# In order to not change anything pre Task Execution API, we leave this setting as it was (unset) in Airflow2 +if AIRFLOW_V_3_0_PLUS: + DEFAULT_CELERY_CONFIG.setdefault("worker_redirect_stdouts", False) + DEFAULT_CELERY_CONFIG.setdefault("worker_hijack_root_logger", False) + def _get_celery_ssl_active() -> bool: try: @@ -126,9 +135,7 @@ def _get_celery_ssl_active() -> bool: DEFAULT_CELERY_CONFIG["broker_use_ssl"] = broker_use_ssl except AirflowConfigException: raise AirflowException( - "AirflowConfigException: SSL_ACTIVE is True, " - "please ensure SSL_KEY, " - "SSL_CERT and SSL_CACERT are set" + "AirflowConfigException: SSL_ACTIVE is True, please ensure SSL_KEY, SSL_CERT and SSL_CACERT are set" ) except Exception as e: raise AirflowException( diff --git a/providers/celery/src/airflow/providers/celery/get_provider_info.py b/providers/celery/src/airflow/providers/celery/get_provider_info.py index 2c0a3d70a58775..0cc999f5baf051 100644 --- a/providers/celery/src/airflow/providers/celery/get_provider_info.py +++ b/providers/celery/src/airflow/providers/celery/get_provider_info.py @@ -304,7 +304,7 @@ def get_provider_info(): }, "dependencies": [ "apache-airflow>=2.9.0", - "celery[redis]>=5.3.0,<6,!=5.3.3,!=5.3.2", + "celery[redis]>=5.4.0,<6", "flower>=1.0.0", "google-re2>=1.0", ], diff --git a/providers/celery/tests/provider_tests/celery/executors/test_celery_executor.py b/providers/celery/tests/provider_tests/celery/executors/test_celery_executor.py index 7a33e0cfbc17ca..22dbd59a914c13 100644 --- a/providers/celery/tests/provider_tests/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/provider_tests/celery/executors/test_celery_executor.py @@ -44,7 +44,7 @@ from tests_common.test_utils import db from tests_common.test_utils.config import 
conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS pytestmark = pytest.mark.db_test @@ -71,21 +71,24 @@ def task_id(self): def _prepare_app(broker_url=None, execute=None): broker_url = broker_url or conf.get("celery", "BROKER_URL") - execute = execute or celery_executor_utils.execute_command.__wrapped__ + if AIRFLOW_V_3_0_PLUS: + execute_name = "execute_workload" + execute = execute or celery_executor_utils.execute_workload.__wrapped__ + else: + execute_name = "execute_command" + execute = execute or celery_executor_utils.execute_command.__wrapped__ test_config = dict(celery_executor_utils.celery_configuration) test_config.update({"broker_url": broker_url}) test_app = Celery(broker_url, config_source=test_config) test_execute = test_app.task(execute) - patch_app = mock.patch("airflow.providers.celery.executors.celery_executor_utils.app", test_app) - patch_execute = mock.patch( - "airflow.providers.celery.executors.celery_executor_utils.execute_command", test_execute - ) + patch_app = mock.patch.object(celery_executor_utils, "app", test_app) + patch_execute = mock.patch.object(celery_executor_utils, execute_name, test_execute) backend = test_app.backend if hasattr(backend, "ResultSession"): - # Pre-create the database tables now, otherwise SQLA vis Celery has a + # Pre-create the database tables now, otherwise SQLA via Celery has a # race condition where it one of the subprocesses can die with "Table # already exists" error, because SQLA checks for which tables exist, # then issues a CREATE TABLE, rather than doing CREATE TABLE IF NOT @@ -147,6 +150,7 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock ] mock_stats_gauge.assert_has_calls(calls) + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Airflow 3 doesn't have execute_command anymore") @pytest.mark.parametrize( "command, raise_exception", [ diff --git a/providers/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py b/providers/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py index f7226fbed71585..4a4d1f9fb6d45d 100644 --- a/providers/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py +++ b/providers/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py @@ -56,7 +56,7 @@ ) from airflow.utils.state import State -CommandType = list[str] +CommandType = Sequence[str] ExecutorConfigType = dict[str, Any] INVALID_CREDENTIALS_EXCEPTIONS = [ @@ -350,7 +350,7 @@ def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, self.pending_jobs.append( BatchQueuedJob( key=key, - command=command, + command=list(command), queue=queue, executor_config=executor_config or {}, attempt_number=1, diff --git a/providers/src/airflow/providers/amazon/aws/executors/ecs/utils.py b/providers/src/airflow/providers/amazon/aws/executors/ecs/utils.py index 39b266253ce055..8024e6181db458 100644 --- a/providers/src/airflow/providers/amazon/aws/executors/ecs/utils.py +++ b/providers/src/airflow/providers/amazon/aws/executors/ecs/utils.py @@ -25,6 +25,7 @@ import datetime from collections import defaultdict +from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable @@ -36,7 +37,7 @@ if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstanceKey -CommandType = list[str] +CommandType = Sequence[str] ExecutorConfigFunctionType = Callable[[CommandType], dict] 
ExecutorConfigType = dict[str, Any] diff --git a/providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py index 702703b2142e02..15fa954439a9db 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py +++ b/providers/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py @@ -411,7 +411,7 @@ def run_next(self, next_job: KubernetesJobType) -> None: map_index=map_index, date=None, run_id=run_id, - args=command, + args=list(command), pod_override_object=kube_executor_config, base_worker_pod=base_worker_pod, with_mutation_hook=True, diff --git a/providers/tests/amazon/aws/executors/batch/test_batch_executor.py b/providers/tests/amazon/aws/executors/batch/test_batch_executor.py index 3b02d11250125e..809187cc9cf1a2 100644 --- a/providers/tests/amazon/aws/executors/batch/test_batch_executor.py +++ b/providers/tests/amazon/aws/executors/batch/test_batch_executor.py @@ -189,7 +189,7 @@ class TestAwsBatchExecutor: def test_execute(self, mock_executor): """Test execution from end-to-end""" airflow_key = mock.Mock(spec=tuple) - airflow_cmd = mock.Mock(spec=list) + airflow_cmd = ["1", "2"] mock_executor.batch.submit_job.return_value = {"jobId": MOCK_JOB_ID, "jobName": "some-job-name"} @@ -209,8 +209,8 @@ def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor): failed jobs are added back to the pending_jobs queue to be run in the next iteration. """ airflow_key = TaskInstanceKey("a", "b", "c", 1, -1) - airflow_cmd1 = mock.Mock(spec=list) - airflow_cmd2 = mock.Mock(spec=list) + airflow_cmd1 = ["1", "2"] + airflow_cmd2 = ["3", "4"] airflow_commands = [airflow_cmd1, airflow_cmd2] responses = [Exception("Failure 1"), {"jobId": "job-2"}] @@ -238,8 +238,8 @@ def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor): assert len(mock_executor.active_workers.get_all_jobs()) == 1 # Add more tasks to pending_jobs. This simulates tasks being scheduled by Airflow - airflow_cmd3 = mock.Mock(spec=list) - airflow_cmd4 = mock.Mock(spec=list) + airflow_cmd3 = ["5", "6"] + airflow_cmd4 = ["7", "8"] airflow_commands.extend([airflow_cmd1, airflow_cmd3, airflow_cmd4]) responses.extend([Exception("Failure 1"), {"jobId": "job-3"}, {"jobId": "job-4"}]) mock_executor.execute_async(airflow_key, airflow_cmd3) @@ -277,8 +277,8 @@ def test_attempt_all_jobs_when_jobs_fail(self, _, mock_executor): until all the tasks have been attempted the maximum number of times. 
""" airflow_key = TaskInstanceKey("a", "b", "c", 1, -1) - airflow_cmd1 = mock.Mock(spec=list) - airflow_cmd2 = mock.Mock(spec=list) + airflow_cmd1 = ["1", "2"] + airflow_cmd2 = ["3", "4"] commands = [airflow_cmd1, airflow_cmd2] failures = [Exception("Failure 1"), Exception("Failure 2")] submit_job_args = { @@ -339,7 +339,7 @@ def test_attempt_submit_jobs_failure(self, mock_executor): def test_task_retry_on_api_failure(self, _, mock_executor, caplog): """Test API failure retries""" airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"] - airflow_cmds = [mock.Mock(spec=list), mock.Mock(spec=list)] + airflow_cmds = [["1", "2"], ["3", "4"]] mock_executor.execute_async(airflow_keys[0], airflow_cmds[0]) mock_executor.execute_async(airflow_keys[1], airflow_cmds[1]) diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 30050c0b955191..3a65247e0ee40a 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -166,38 +166,27 @@ def _configure_logs_over_json_channel(log_fd: int): from airflow.sdk.log import configure_logging log_io = os.fdopen(log_fd, "wb", buffering=0) - configure_logging(enable_pretty_log=False, output=log_io) + configure_logging(enable_pretty_log=False, output=log_io, sending_to_supervisor=True) def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr): - if "PYTEST_CURRENT_TEST" in os.environ: - # When we are running in pytest, it's output capturing messes us up. This works around it - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the # pipes from the supervisor - for handle_name, sock, mode in ( - ("stdin", child_stdin, "r"), - ("stdout", child_stdout, "w"), - ("stderr", child_stderr, "w"), + for handle_name, fd, sock, mode in ( + ("stdin", 0, child_stdin, "r"), + ("stdout", 1, child_stdout, "w"), + ("stderr", 2, child_stderr, "w"), ): handle = getattr(sys, handle_name) - try: - fd = handle.fileno() - os.dup2(sock.fileno(), fd) - # dup2 creates another open copy of the fd, we can close the "socket" copy of it. - sock.close() - except io.UnsupportedOperation: - if "PYTEST_CURRENT_TEST" in os.environ: - # When we're running under pytest, the stdin is not a real filehandle with an fd, so we need - # to handle that differently - fd = sock.fileno() - else: - raise - # We can't open text mode fully unbuffered (python throws an exception if we try), but we can make it line buffered with `buffering=1` - handle = os.fdopen(fd, mode, buffering=1) + handle.close() + os.dup2(sock.fileno(), fd) + del sock + + # We open the socket/fd as binary, and then pass it to a TextIOWrapper so that it looks more like a + # normal sys.stdout etc. + binary = os.fdopen(fd, mode + "b") + handle = io.TextIOWrapper(binary, line_buffering=True) setattr(sys, handle_name, handle) @@ -352,8 +341,19 @@ def start( del constructor_kwargs del logger - # Run the child entrypoint - _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) + try: + # Run the child entrypoint + _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) + except BaseException as e: + try: + # We can't use log here, as if we except out of _fork_main something _weird_ went on. 
+                print("Exception in _fork_main, exiting with code 124", e, file=sys.stderr)
+            except BaseException:
+                pass
+
+        # It's really super super important we never exit this block. We are in the forked child, and if we
+        # do then _THINGS GET WEIRD_. (Normally `_fork_main` itself will `_exit()` so we never get here)
+        os._exit(124)
 
     requests_fd = child_comms.fileno()
diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py
index fa5b113588bf55..8549518e205b22 100644
--- a/task_sdk/src/airflow/sdk/log.py
+++ b/task_sdk/src/airflow/sdk/log.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import io
 import itertools
 import logging.config
 import os
@@ -39,7 +40,9 @@
 ]
 
 
-def exception_group_tracebacks(format_exception: Callable[[ExcInfo], list[dict[str, Any]]]) -> Processor:
+def exception_group_tracebacks(
+    format_exception: Callable[[ExcInfo], list[dict[str, Any]]],
+) -> Processor:
     # Make mypy happy
     if not hasattr(__builtins__, "BaseExceptionGroup"):
         T = TypeVar("T")
@@ -178,13 +181,6 @@ def logging_processors(
             "console": console,
         }
     else:
-        # Imports to suppress showing code from these modules
-        import contextlib
-
-        import click
-        import httpcore
-        import httpx
-
         dict_exc_formatter = structlog.tracebacks.ExceptionDictTransformer(
             use_rich=False, show_locals=False, suppress=suppress
         )
@@ -197,9 +193,19 @@
             exc_group_processor = None
 
         def json_dumps(msg, default):
+            # Note: this is likely an "expensive" step, but let's massage the dict order for nice
+            # viewing of the raw JSON logs.
+            # Maybe we don't need this once the UI renders the JSON instead of displaying the raw text
+            msg = {
+                "timestamp": msg.pop("timestamp"),
+                "level": msg.pop("level"),
+                "event": msg.pop("event"),
+                **msg,
+            }
             return msgspec.json.encode(msg, enc_hook=default)
 
         def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str:
+            # Stdlib logging doesn't need the re-ordering, it's fine as it is
            return msgspec.json.encode(event_dict).decode("utf-8")
 
         json = structlog.processors.JSONRenderer(serializer=json_dumps)
@@ -224,13 +230,11 @@ def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str:
 def configure_logging(
     enable_pretty_log: bool = True,
     log_level: str = "DEBUG",
-    output: BinaryIO | None = None,
+    output: BinaryIO | TextIO | None = None,
     cache_logger_on_first_use: bool = True,
+    sending_to_supervisor: bool = False,
 ):
     """Set up struct logging and stdlib logging config."""
-    if enable_pretty_log and output is not None:
-        raise ValueError("output can only be set if enable_pretty_log is not")
-
     lvl = structlog.stdlib.NAME_TO_LEVEL[log_level.lower()]
 
     if enable_pretty_log:
@@ -263,13 +267,30 @@ def configure_logging(
     wrapper_class = structlog.make_filtering_bound_logger(lvl)
 
     if enable_pretty_log:
+        if output is not None and not isinstance(output, TextIO):
+            wrapper = io.TextIOWrapper(output, line_buffering=True)
+            logger_factory = structlog.WriteLoggerFactory(wrapper)
+        else:
+            logger_factory = structlog.WriteLoggerFactory(output)
         structlog.configure(
             processors=processors,
             cache_logger_on_first_use=cache_logger_on_first_use,
             wrapper_class=wrapper_class,
+            logger_factory=logger_factory,
         )
         color_formatter.append(named["console"])
     else:
+        if output is not None and "b" not in output.mode:
+            if not hasattr(output, "buffer"):
+                raise ValueError(
+                    f"output needed to be a binary stream, but it didn't have a buffer attribute ({output=})"
+                )
+            else:
+                output = output.buffer
+        if TYPE_CHECKING:
+            # Not all binary
streams are isinstance of BinaryIO, so we check via looking at `mode` at + # runtime. mypy doesn't grok that though + assert isinstance(output, BinaryIO) structlog.configure( processors=processors, cache_logger_on_first_use=cache_logger_on_first_use, @@ -324,7 +345,7 @@ def configure_logging( "loggers": { # Set Airflow logging to the level requested, but most everything else at "INFO" "": { - "handlers": ["to_supervisor" if output else "default"], + "handlers": ["to_supervisor" if sending_to_supervisor else "default"], "level": "INFO", "propagate": True, }, @@ -413,10 +434,12 @@ def init_log_file(local_relative_path: str) -> Path: from airflow.configuration import conf new_file_permissions = int( - conf.get("logging", "file_task_handler_new_file_permissions", fallback="0o664"), 8 + conf.get("logging", "file_task_handler_new_file_permissions", fallback="0o664"), + 8, ) new_folder_permissions = int( - conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), 8 + conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), + 8, ) base_log_folder = conf.get("logging", "base_log_folder") diff --git a/tests/integration/executors/test_celery_executor.py b/tests/integration/executors/test_celery_executor.py index 0f9f0b45ae9c1a..e1ece9b387bbaf 100644 --- a/tests/integration/executors/test_celery_executor.py +++ b/tests/integration/executors/test_celery_executor.py @@ -39,14 +39,15 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowTaskTimeout -from airflow.executors import base_executor +from airflow.executors import base_executor, workloads from airflow.models.dag import DAG -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance +from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.standard.operators.bash import BashOperator from airflow.utils.state import State, TaskInstanceState from tests_common.test_utils import db +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS logger = logging.getLogger(__name__) @@ -69,16 +70,19 @@ def _prepare_app(broker_url=None, execute=None): from airflow.providers.celery.executors import celery_executor_utils broker_url = broker_url or conf.get("celery", "BROKER_URL") - execute = execute or celery_executor_utils.execute_command.__wrapped__ + if AIRFLOW_V_3_0_PLUS: + execute_name = "execute_workload" + execute = execute or celery_executor_utils.execute_workload.__wrapped__ + else: + execute_name = "execute_command" + execute = execute or celery_executor_utils.execute_command.__wrapped__ test_config = dict(celery_executor_utils.celery_configuration) test_config.update({"broker_url": broker_url}) test_app = Celery(broker_url, config_source=test_config) test_execute = test_app.task(execute) - patch_app = mock.patch("airflow.providers.celery.executors.celery_executor.app", test_app) - patch_execute = mock.patch( - "airflow.providers.celery.executors.celery_executor_utils.execute_command", test_execute - ) + patch_app = mock.patch.object(celery_executor_utils, "app", test_app) + patch_execute = mock.patch.object(celery_executor_utils, execute_name, test_execute) backend = test_app.backend @@ -136,42 +140,35 @@ def _change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=Non def test_celery_integration(self, broker_url): from airflow.providers.celery.executors import celery_executor, celery_executor_utils - success_command = ["airflow", "tasks", 
"run", "true", "some_parameter"] - fail_command = ["airflow", "version"] - - def fake_execute_command(command): - if command != success_command: + def fake_execute_workload(command): + if "fail" in command: raise AirflowException("fail") - with _prepare_app(broker_url, execute=fake_execute_command) as app: + with _prepare_app(broker_url, execute=fake_execute_workload) as app: executor = celery_executor.CeleryExecutor() assert executor.tasks == {} executor.start() with start_worker(app=app, logfile=sys.stdout, loglevel="info"): - execute_date = datetime.now() - - task_tuples_to_send = [ - ( - ("success", "fake_simple_ti", execute_date, 0), - success_command, - celery_executor_utils.celery_configuration["task_default_queue"], - celery_executor_utils.execute_command, - ), - ( - ("fail", "fake_simple_ti", execute_date, 0), - fail_command, - celery_executor_utils.celery_configuration["task_default_queue"], - celery_executor_utils.execute_command, - ), + ti = workloads.TaskInstance.model_construct( + task_id="success", + dag_id="id", + run_id="abc", + try_number=0, + priority_weight=1, + queue=celery_executor_utils.celery_configuration["task_default_queue"], + ) + keys = [ + TaskInstanceKey("id", "success", "abc", 0, -1), + TaskInstanceKey("id", "fail", "abc", 0, -1), ] + for w in ( + workloads.ExecuteTask.model_construct(ti=ti), + workloads.ExecuteTask.model_construct(ti=ti.model_copy(update={"task_id": "fail"})), + ): + executor.queue_workload(w, session=None) - # "Enqueue" them. We don't have a real SimpleTaskInstance, so directly edit the dict - for key, command, queue, _ in task_tuples_to_send: - executor.queued_tasks[key] = (command, 1, queue, None) - executor.task_publish_retries[key] = 1 - - executor._process_tasks(task_tuples_to_send) + executor.trigger_tasks(open_slots=10) for _ in range(20): num_tasks = len(executor.tasks.keys()) if num_tasks == 2: @@ -181,54 +178,47 @@ def fake_execute_command(command): num_tasks, ) sleep(0.4) - assert list(executor.tasks.keys()) == [ - ("success", "fake_simple_ti", execute_date, 0), - ("fail", "fake_simple_ti", execute_date, 0), - ] - assert ( - executor.event_buffer[("success", "fake_simple_ti", execute_date, 0)][0] == State.QUEUED - ) - assert executor.event_buffer[("fail", "fake_simple_ti", execute_date, 0)][0] == State.QUEUED + assert list(executor.tasks.keys()) == keys + assert executor.event_buffer[keys[0]][0] == State.QUEUED + assert executor.event_buffer[keys[1]][0] == State.QUEUED executor.end(synchronous=True) - assert executor.event_buffer[("success", "fake_simple_ti", execute_date, 0)][0] == State.SUCCESS - assert executor.event_buffer[("fail", "fake_simple_ti", execute_date, 0)][0] == State.FAILED + assert executor.event_buffer[keys[0]][0] == State.SUCCESS + assert executor.event_buffer[keys[1]][0] == State.FAILED - assert "success" not in executor.tasks - assert "fail" not in executor.tasks + assert keys[0] not in executor.tasks + assert keys[1] not in executor.tasks assert executor.queued_tasks == {} def test_error_sending_task(self): from airflow.providers.celery.executors import celery_executor - def fake_execute_command(): + def fake_task(): pass - with _prepare_app(execute=fake_execute_command): - # fake_execute_command takes no arguments while execute_command takes 1, + with _prepare_app(execute=fake_task): + # fake_execute_command takes no arguments while execute_workload takes 1, # which will cause TypeError when calling task.apply_async() executor = celery_executor.CeleryExecutor() task = BashOperator( task_id="test", 
bash_command="true", - dag=DAG(dag_id="id", schedule=None), + dag=DAG(dag_id="dag_id"), start_date=datetime.now(), ) - when = datetime.now() - value_tuple = ( - "command", - 1, - None, - SimpleTaskInstance.from_ti(ti=TaskInstance(task=task, run_id=None)), + ti = TaskInstance(task=task, run_id="abc") + workload = workloads.ExecuteTask.model_construct( + ti=workloads.TaskInstance.model_validate(ti, from_attributes=True), ) - key = ("fail", "fake_simple_ti", when, 0) - executor.queued_tasks[key] = value_tuple + + key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1) + executor.queued_tasks[key] = workload executor.task_publish_retries[key] = 1 executor.heartbeat() assert len(executor.queued_tasks) == 0, "Task should no longer be queued" - assert executor.event_buffer[("fail", "fake_simple_ti", when, 0)][0] == State.FAILED + assert executor.event_buffer[key][0] == State.FAILED def test_retry_on_error_sending_task(self, caplog): """Test that Airflow retries publishing tasks to Celery Broker at least 3 times""" @@ -251,18 +241,16 @@ def test_retry_on_error_sending_task(self, caplog): task = BashOperator( task_id="test", bash_command="true", - dag=DAG(dag_id="id", schedule=None), + dag=DAG(dag_id="id"), start_date=datetime.now(), ) - when = datetime.now() - value_tuple = ( - "command", - 1, - None, - SimpleTaskInstance.from_ti(ti=TaskInstance(task=task, run_id=None)), + ti = TaskInstance(task=task, run_id="abc") + workload = workloads.ExecuteTask.model_construct( + ti=workloads.TaskInstance.model_validate(ti, from_attributes=True), ) - key = ("fail", "fake_simple_ti", when, 0) - executor.queued_tasks[key] = value_tuple + + key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1) + executor.queued_tasks[key] = workload # Test that when heartbeat is called again, task is published again to Celery Queue executor.heartbeat() @@ -286,7 +274,7 @@ def test_retry_on_error_sending_task(self, caplog): executor.heartbeat() assert dict(executor.task_publish_retries) == {} assert len(executor.queued_tasks) == 0, "Task should no longer be in queue" - assert executor.event_buffer[("fail", "fake_simple_ti", when, 0)][0] == State.FAILED + assert executor.event_buffer[key][0] == State.FAILED class ClassWithCustomAttributes: