Skip to content

Commit

Permalink
Expand KPO template_fields, fix Spark k8s operator tests (apache#46268)
Browse files Browse the repository at this point in the history
* Add name and hostname to KPO template_fields

* Add fetch container mock to name normalization test

* Fix bugged tests

* Run execute in test_pod_name to get validation fail

* Add long name case to KPO unit tests
  • Loading branch information
insomnes authored and ambika-garg committed Feb 13, 2025
1 parent f9a658c commit a19e58f
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 178 deletions.
26 changes: 15 additions & 11 deletions kubernetes_tests/test_kubernetes_pod_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,18 +1054,22 @@ def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock):

def test_pod_name(self, mock_get_connection):
pod_name_too_long = "a" * 221
k = KubernetesPodOperator(
namespace="default",
image="ubuntu:16.04",
cmds=["bash", "-cx"],
arguments=["echo 10"],
labels=self.labels,
name=pod_name_too_long,
task_id=str(uuid4()),
in_cluster=False,
do_xcom_push=False,
)
# Name is now in template fields, and it's final value requires context
# so we need to execute for name validation
context = create_context(k)
with pytest.raises(AirflowException):
KubernetesPodOperator(
namespace="default",
image="ubuntu:16.04",
cmds=["bash", "-cx"],
arguments=["echo 10"],
labels=self.labels,
name=pod_name_too_long,
task_id=str(uuid4()),
in_cluster=False,
do_xcom_push=False,
)
k.execute(context)

def test_on_kill(self, mock_get_connection):
hook = KubernetesHook(conn_id=None, in_cluster=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ class KubernetesPodOperator(BaseOperator):

template_fields: Sequence[str] = (
"image",
"name",
"hostname",
"cmds",
"annotations",
"arguments",
Expand Down Expand Up @@ -391,7 +393,7 @@ def __init__(
self.priority_class_name = priority_class_name
self.pod_template_file = pod_template_file
self.pod_template_dict = pod_template_dict
self.name = self._set_name(name)
self.name = name
self.random_name_suffix = random_name_suffix
self.termination_grace_period = termination_grace_period
self.pod_request_obj: k8s.V1Pod | None = None
Expand Down Expand Up @@ -587,6 +589,7 @@ def extract_xcom(self, pod: k8s.V1Pod) -> dict[Any, Any] | None:

def execute(self, context: Context):
"""Based on the deferrable parameter runs the pod asynchronously or synchronously."""
self.name = self._set_name(self.name)
if not self.deferrable:
return self.execute_sync(context)

Expand Down
38 changes: 38 additions & 0 deletions providers/tests/cncf/kubernetes/operators/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def test_templates(self, create_task_instance_of_operator, session):
session=session,
dag_id=dag_id,
task_id="task-id",
name="{{ dag.dag_id }}",
hostname="{{ dag.dag_id }}",
namespace="{{ dag.dag_id }}",
container_resources=k8s.V1ResourceRequirements(
requests={"memory": "{{ dag.dag_id }}", "cpu": "{{ dag.dag_id }}"},
Expand Down Expand Up @@ -189,6 +191,8 @@ def test_templates(self, create_task_instance_of_operator, session):
assert dag_id == rendered.volume_mounts[0].sub_path
assert dag_id == ti.task.image
assert dag_id == ti.task.cmds
assert dag_id == ti.task.name
assert dag_id == ti.task.hostname
assert dag_id == ti.task.namespace
assert dag_id == ti.task.config_file
assert dag_id == ti.task.labels
Expand Down Expand Up @@ -1150,6 +1154,40 @@ def test_no_handle_failure_on_success(self, fetch_container_mock):
# assert does not raise
self.run_pod(k)

@pytest.mark.parametrize("randomize", [True, False])
@patch(f"{POD_MANAGER_CLASS}.await_container_completion", new=MagicMock)
@patch(f"{POD_MANAGER_CLASS}.fetch_requested_container_logs")
def test_name_normalized_on_execution(self, fetch_container_mock, randomize):
name_base = "test_extra-123"
normalized_name = "test-extra-123"

k = KubernetesPodOperator(
name=name_base,
random_name_suffix=randomize,
task_id="task",
get_logs=False,
)

pod, _ = self.run_pod(k)
if randomize:
# To avoid
assert isinstance(pod.metadata.name, str)
assert pod.metadata.name.startswith(normalized_name)
assert k.name.startswith(normalized_name)
else:
assert pod.metadata.name == normalized_name
assert k.name == normalized_name

@pytest.mark.parametrize("name", ["name@extra", "a" * 300], ids=["bad", "long"])
def test_name_validation_on_execution(self, name):
k = KubernetesPodOperator(
name=name,
task_id="task",
)

with pytest.raises(AirflowException):
self.run_pod(k)

def test_create_with_affinity(self):
affinity = {
"nodeAffinity": {
Expand Down
Loading

0 comments on commit a19e58f

Please sign in to comment.