diff --git a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
index 3c4b31893698c0..e1e7e85bcdab62 100644
--- a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -20,6 +20,7 @@
 import copy
 import json
 from datetime import date
+from functools import cached_property
 from unittest import mock
 from unittest.mock import patch
 from uuid import uuid4
@@ -203,11 +204,10 @@ def create_context(task):
 @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_start")
 @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
 @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.client")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.create_job_name")
 @patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup")
 @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object_status")
 @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object")
-class TestSparkKubernetesOperator:
+class TestSparkKubernetesOperatorCreateApplication:
     def setup_method(self):
         db.merge_conn(
             Connection(conn_id="kubernetes_default_kube_config", conn_type="kubernetes", extra=json.dumps({}))
@@ -222,268 +222,270 @@ def setup_method(self):
         args = {"owner": "airflow", "start_date": timezone.datetime(2020, 2, 1)}
         self.dag = DAG("test_dag_id", schedule=None, default_args=args)
 
-    def execute_operator(self, task_name, mock_create_job_name, job_spec):
-        mock_create_job_name.return_value = task_name
-        op = SparkKubernetesOperator(
-            template_spec=job_spec,
-            kubernetes_conn_id="kubernetes_default_kube_config",
-            task_id=task_name,
-            get_logs=True,
-        )
-        context = create_context(op)
-        op.execute(context)
-        return op
-
-    def test_create_application_from_yaml_json(
+    def execute_operator(
         self,
-        mock_create_namespaced_crd,
-        mock_get_namespaced_custom_object_status,
-        mock_cleanup,
-        mock_create_job_name,
-        mock_get_kube_client,
-        mock_create_pod,
-        mock_await_pod_start,
-        mock_await_pod_completion,
-        mock_fetch_requested_container_logs,
-        data_file,
+        task_name,
+        *,
+        name=None,
+        job_spec=None,
+        application_file=None,
+        random_name_suffix=False,
     ):
-        task_name = "default_yaml"
-        mock_create_job_name.return_value = task_name
         op = SparkKubernetesOperator(
-            application_file=data_file("spark/application_test.yaml").as_posix(),
-            kubernetes_conn_id="kubernetes_default_kube_config",
+            name=name,
             task_id=task_name,
-        )
-        context = create_context(op)
-        op.execute(context)
-        TEST_APPLICATION_DICT["metadata"]["name"] = task_name
-        mock_create_namespaced_crd.assert_called_with(
-            body=TEST_APPLICATION_DICT,
-            group="sparkoperator.k8s.io",
-            namespace="default",
-            plural="sparkapplications",
-            version="v1beta2",
-        )
-
-        task_name = "default_json"
-        mock_create_job_name.return_value = task_name
-        op = SparkKubernetesOperator(
-            application_file=data_file("spark/application_test.json").as_posix(),
+            random_name_suffix=random_name_suffix,
+            application_file=application_file,
+            template_spec=job_spec,
             kubernetes_conn_id="kubernetes_default_kube_config",
-            task_id=task_name,
         )
         context = create_context(op)
         op.execute(context)
-        TEST_APPLICATION_DICT["metadata"]["name"] = task_name
-        mock_create_namespaced_crd.assert_called_with(
-            body=TEST_APPLICATION_DICT,
-            group="sparkoperator.k8s.io",
-            namespace="default",
-            plural="sparkapplications",
-            version="v1beta2",
-        )
+        return op
+
+    @cached_property
+    def call_commons(self):
+        return {
+            "group": "sparkoperator.k8s.io",
+            "namespace": "default",
+            "plural": "sparkapplications",
+            "version": "v1beta2",
+        }
 
-    def test_create_application_from_yaml_json_and_use_name_from_metadata(
+    @pytest.mark.parametrize(
+        "task_name, application_file_path",
+        [
+            ("default_yaml", "spark/application_test.yaml"),
+            ("default_json", "spark/application_test.json"),
+        ],
+    )
+    @pytest.mark.parametrize("random_name_suffix", [True, False])
+    def test_create_application(
         self,
         mock_create_namespaced_crd,
         mock_get_namespaced_custom_object_status,
         mock_cleanup,
-        mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
         mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
+        task_name,
+        application_file_path,
+        random_name_suffix,
     ):
-        op = SparkKubernetesOperator(
-            application_file=data_file("spark/application_test.yaml").as_posix(),
-            kubernetes_conn_id="kubernetes_default_kube_config",
-            task_id="create_app_and_use_name_from_metadata",
-        )
-        context = create_context(op)
-        op.execute(context)
-        TEST_APPLICATION_DICT["metadata"]["name"] = op.name
-        mock_create_namespaced_crd.assert_called_with(
-            body=TEST_APPLICATION_DICT,
-            group="sparkoperator.k8s.io",
-            namespace="default",
-            plural="sparkapplications",
-            version="v1beta2",
+        done_op = self.execute_operator(
+            task_name=task_name,
+            application_file=data_file(application_file_path).as_posix(),
+            random_name_suffix=random_name_suffix,
         )
-        assert op.name.startswith("default_yaml")
+        assert done_op.task_id == task_name
+        # The name generation is out of scope of this test, so we don't set any
+        # expectations on the name. We just check that the name is set.
+        assert isinstance(done_op.name, str)
+        assert done_op.name != ""
 
-        op = SparkKubernetesOperator(
-            application_file=data_file("spark/application_test.json").as_posix(),
-            kubernetes_conn_id="kubernetes_default_kube_config",
-            task_id="create_app_and_use_name_from_metadata",
-        )
-        context = create_context(op)
-        op.execute(context)
-        TEST_APPLICATION_DICT["metadata"]["name"] = op.name
+        TEST_APPLICATION_DICT["metadata"]["name"] = done_op.name
         mock_create_namespaced_crd.assert_called_with(
             body=TEST_APPLICATION_DICT,
-            group="sparkoperator.k8s.io",
-            namespace="default",
-            plural="sparkapplications",
-            version="v1beta2",
+            **self.call_commons,
         )
-        assert op.name.startswith("default_json")
 
-    def test_create_application_from_yaml_json_and_use_name_from_operator_args(
+    @pytest.mark.parametrize(
+        "task_name, application_file_path",
+        [
+            ("default_yaml", "spark/application_test.yaml"),
+            ("default_json", "spark/application_test.json"),
+        ],
+    )
+    @pytest.mark.parametrize("random_name_suffix", [True, False])
+    def test_create_application_and_use_name_from_operator_args(
         self,
         mock_create_namespaced_crd,
         mock_get_namespaced_custom_object_status,
         mock_cleanup,
-        mock_create_job_name,
        mock_get_kube_client,
         mock_create_pod,
         mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
+        task_name,
+        application_file_path,
+        random_name_suffix,
     ):
-        op = SparkKubernetesOperator(
-            application_file=data_file("spark/application_test.yaml").as_posix(),
-            kubernetes_conn_id="kubernetes_default_kube_config",
-            task_id="default_yaml",
-            name="test-spark",
+        name = f"test-name-{task_name}"
+        done_op = self.execute_operator(
+            task_name=task_name,
+            name=name,
+            application_file=data_file(application_file_path).as_posix(),
+            random_name_suffix=random_name_suffix,
         )
-        context = create_context(op)
-        op.execute(context)
-        TEST_APPLICATION_DICT["metadata"]["name"] = op.name
-        mock_create_namespaced_crd.assert_called_with(
-            body=TEST_APPLICATION_DICT,
-            group="sparkoperator.k8s.io",
-            namespace="default",
-            plural="sparkapplications",
-            version="v1beta2",
-        )
-        assert op.name.startswith("test-spark")
 
-        op = SparkKubernetesOperator(
-            application_file=data_file("spark/application_test.json").as_posix(),
-            kubernetes_conn_id="kubernetes_default_kube_config",
-            task_id="default_json",
-            name="test-spark",
-        )
-        context = create_context(op)
-        op.execute(context)
-        TEST_APPLICATION_DICT["metadata"]["name"] = op.name
+        name_normalized = name.replace("_", "-")
+        assert done_op.task_id == task_name
+        assert isinstance(done_op.name, str)
+        if random_name_suffix:
+            assert done_op.name.startswith(name_normalized)
+        else:
+            assert done_op.name == name_normalized
+
+        TEST_APPLICATION_DICT["metadata"]["name"] = done_op.name
         mock_create_namespaced_crd.assert_called_with(
             body=TEST_APPLICATION_DICT,
-            group="sparkoperator.k8s.io",
-            namespace="default",
-            plural="sparkapplications",
-            version="v1beta2",
+            **self.call_commons,
         )
-        assert op.name.startswith("test-spark")
 
-    def test_create_application_from_yaml_json_and_use_name_task_id(
+    @pytest.mark.parametrize(
+        "task_name, application_file_path",
+        [
+            ("task_id_yml", "spark/application_test_with_no_name_from_config.yaml"),
+            ("task_id_json", "spark/application_test_with_no_name_from_config.json"),
+        ],
+    )
+    @pytest.mark.parametrize("random_name_suffix", [True, False])
+    def test_create_application_and_use_name_task_id(
        self,
         mock_create_namespaced_crd,
         mock_get_namespaced_custom_object_status,
         mock_cleanup,
-        mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
         mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
+        task_name,
+        application_file_path,
+        random_name_suffix,
     ):
-        op = SparkKubernetesOperator(
-            application_file=data_file("spark/application_test_with_no_name_from_config.yaml").as_posix(),
-            kubernetes_conn_id="kubernetes_default_kube_config",
-            task_id="create_app_and_use_name_from_task_id",
-        )
-        context = create_context(op)
-        op.execute(context)
-        TEST_APPLICATION_DICT["metadata"]["name"] = op.name
-        mock_create_namespaced_crd.assert_called_with(
-            body=TEST_APPLICATION_DICT,
-            group="sparkoperator.k8s.io",
-            namespace="default",
-            plural="sparkapplications",
-            version="v1beta2",
+        done_op = self.execute_operator(
+            task_name=task_name,
+            application_file=data_file(application_file_path).as_posix(),
+            random_name_suffix=random_name_suffix,
         )
-        assert op.name.startswith("create_app_and_use_name_from_task_id")
 
-        op = SparkKubernetesOperator(
-            application_file=data_file("spark/application_test_with_no_name_from_config.json").as_posix(),
-            kubernetes_conn_id="kubernetes_default_kube_config",
-            task_id="create_app_and_use_name_from_task_id",
-        )
-        context = create_context(op)
-        op.execute(context)
-        TEST_APPLICATION_DICT["metadata"]["name"] = op.name
+        name_normalized = task_name.replace("_", "-")
+        assert isinstance(done_op.name, str)
+        if random_name_suffix:
+            assert done_op.name.startswith(name_normalized)
+        else:
+            assert done_op.name == name_normalized
+
+        TEST_APPLICATION_DICT["metadata"]["name"] = done_op.name
         mock_create_namespaced_crd.assert_called_with(
             body=TEST_APPLICATION_DICT,
-            group="sparkoperator.k8s.io",
-            namespace="default",
-            plural="sparkapplications",
-            version="v1beta2",
+            **self.call_commons,
         )
-        assert op.name.startswith("create_app_and_use_name_from_task_id")
 
+    @pytest.mark.parametrize("random_name_suffix", [True, False])
     def test_new_template_from_yaml(
         self,
         mock_create_namespaced_crd,
         mock_get_namespaced_custom_object_status,
         mock_cleanup,
-        mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
         mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
+        random_name_suffix,
     ):
         task_name = "default_yaml_template"
-        mock_create_job_name.return_value = task_name
-        op = SparkKubernetesOperator(
+        done_op = self.execute_operator(
+            task_name=task_name,
             application_file=data_file("spark/application_template.yaml").as_posix(),
-            kubernetes_conn_id="kubernetes_default_kube_config",
-            task_id=task_name,
+            random_name_suffix=random_name_suffix,
         )
-        context = create_context(op)
-        op.execute(context)
-        TEST_K8S_DICT["metadata"]["name"] = task_name
+
+        name_normalized = task_name.replace("_", "-")
+        assert isinstance(done_op.name, str)
+        if random_name_suffix:
+            assert done_op.name.startswith(name_normalized)
+        else:
+            assert done_op.name == name_normalized
+
+        TEST_K8S_DICT["metadata"]["name"] = done_op.name
         mock_create_namespaced_crd.assert_called_with(
             body=TEST_K8S_DICT,
-            group="sparkoperator.k8s.io",
-            namespace="default",
-            plural="sparkapplications",
-            version="v1beta2",
+            **self.call_commons,
         )
 
+    @pytest.mark.parametrize("random_name_suffix", [True, False])
     def test_template_spec(
         self,
         mock_create_namespaced_crd,
         mock_get_namespaced_custom_object_status,
         mock_cleanup,
-        mock_create_job_name,
         mock_get_kube_client,
         mock_create_pod,
         mock_await_pod_start,
         mock_await_pod_completion,
         mock_fetch_requested_container_logs,
         data_file,
+        random_name_suffix,
    ):
"default_yaml_template" - job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text()) - self.execute_operator(task_name, mock_create_job_name, job_spec=job_spec) + done_op = self.execute_operator( + task_name=task_name, + job_spec=job_spec, + random_name_suffix=random_name_suffix, + ) + + name_normalized = task_name.replace("_", "-") + assert isinstance(done_op.name, str) + if random_name_suffix: + assert done_op.name.startswith(name_normalized) + else: + assert done_op.name == name_normalized - TEST_K8S_DICT["metadata"]["name"] = task_name + TEST_K8S_DICT["metadata"]["name"] = done_op.name mock_create_namespaced_crd.assert_called_with( body=TEST_K8S_DICT, - group="sparkoperator.k8s.io", - namespace="default", - plural="sparkapplications", - version="v1beta2", + **self.call_commons, ) + +@pytest.mark.db_test +@patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_requested_container_logs") +@patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion") +@patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_start") +@patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod") +@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.client") +@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.create_job_name") +@patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup") +@patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object_status") +@patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") +class TestSparkKubernetesOperator: + def setup_method(self): + db.merge_conn( + Connection(conn_id="kubernetes_default_kube_config", conn_type="kubernetes", extra=json.dumps({})) + ) + db.merge_conn( + Connection( + conn_id="kubernetes_with_namespace", + conn_type="kubernetes", + extra=json.dumps({"extra__kubernetes__namespace": "mock_namespace"}), + ) + ) + args = {"owner": "airflow", "start_date": timezone.datetime(2020, 2, 1)} + self.dag = DAG("test_dag_id", schedule=None, default_args=args) + + def execute_operator(self, task_name, mock_create_job_name, job_spec): + mock_create_job_name.return_value = task_name + op = SparkKubernetesOperator( + template_spec=job_spec, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id=task_name, + get_logs=True, + ) + context = create_context(op) + op.execute(context) + return op + def test_env( self, mock_create_namespaced_crd,