From bd0aff86004c28501e0db674076beb4fdd2897fe Mon Sep 17 00:00:00 2001 From: ChristianZaccaria Date: Thu, 20 Jun 2024 12:49:17 +0100 Subject: [PATCH 1/2] Add to ValidatingWebhook to check for SecurityContext in head and worker containers. --- pkg/controllers/raycluster_webhook.go | 41 ++++++++++++++++ pkg/controllers/raycluster_webhook_test.go | 54 ++++++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/pkg/controllers/raycluster_webhook.go b/pkg/controllers/raycluster_webhook.go index dab128115..c827a1ea8 100644 --- a/pkg/controllers/raycluster_webhook.go +++ b/pkg/controllers/raycluster_webhook.go @@ -18,6 +18,7 @@ package controllers import ( "context" + "reflect" "strconv" rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" @@ -133,6 +134,7 @@ func (w *rayClusterWebhook) ValidateCreate(ctx context.Context, obj runtime.Obje var allErrors field.ErrorList allErrors = append(allErrors, validateIngress(rayCluster)...) + allErrors = append(allErrors, validateSecurityContext(rayCluster)...) if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) { allErrors = append(allErrors, validateOAuthProxyContainer(rayCluster)...) @@ -155,6 +157,7 @@ func (w *rayClusterWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj r } allErrors = append(allErrors, validateIngress(rayCluster)...) + allErrors = append(allErrors, validateSecurityContext(rayCluster)...) if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) { allErrors = append(allErrors, validateOAuthProxyContainer(rayCluster)...) @@ -202,6 +205,32 @@ func validateOAuthProxyVolume(rayCluster *rayv1.RayCluster) field.ErrorList { return allErrors } +func validateSecurityContext(rayCluster *rayv1.RayCluster) field.ErrorList { + var allErrors field.ErrorList + + for i := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + if !reflect.DeepEqual(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext, securityContext()) { + allErrors = append(allErrors, field.Invalid( + field.NewPath("spec", "headGroupSpec", "template", "spec", "containers", strconv.Itoa(i), "securityContext"), + rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext, + "SecurityContext is immutable")) + } + } + + for i := range rayCluster.Spec.WorkerGroupSpecs { + for j := range rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers { + if !reflect.DeepEqual(rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext, securityContext()) { + allErrors = append(allErrors, field.Invalid( + field.NewPath("spec", "workerGroupSpecs", strconv.Itoa(i), "template", "spec", "containers", strconv.Itoa(j), "securityContext"), + rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext, + "SecurityContext is immutable")) + } + } + } + + return allErrors +} + func validateIngress(rayCluster *rayv1.RayCluster) field.ErrorList { var allErrors field.ErrorList @@ -268,6 +297,18 @@ func oauthProxyContainer(rayCluster *rayv1.RayCluster) corev1.Container { } } +func securityContext() *corev1.SecurityContext { + return &corev1.SecurityContext{ + AllowPrivilegeEscalation: ptr.To(false), + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + }, + SeccompProfile: &corev1.SeccompProfile{ + Type: "RuntimeDefault", + }, + } +} + func oauthProxyTLSSecretVolume(rayCluster *rayv1.RayCluster) corev1.Volume { return corev1.Volume{ Name: oauthProxyVolumeName, diff --git a/pkg/controllers/raycluster_webhook_test.go b/pkg/controllers/raycluster_webhook_test.go index 44927309d..373bc186d 100644 --- a/pkg/controllers/raycluster_webhook_test.go +++ b/pkg/controllers/raycluster_webhook_test.go @@ -26,6 +26,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/utils/ptr" "github.com/project-codeflare/codeflare-operator/pkg/config" ) @@ -277,6 +278,15 @@ func TestValidateCreate(t *testing.T) { ReadOnly: true, }, }, + SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: ptr.To(false), + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + }, + SeccompProfile: &corev1.SeccompProfile{ + Type: "RuntimeDefault", + }, + }, }, }, Volumes: []corev1.Volume{ @@ -346,6 +356,14 @@ func TestValidateCreate(t *testing.T) { test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateCreate function due to manipulated head group service account name") }) + t.Run("Negative: Expected errors on call to ValidateCreate function due to manipulated head group container SecurityContext", func(t *testing.T) { + for i := range invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext.AllowPrivilegeEscalation = ptr.To(true) + } + _, err = rcWebhook.ValidateCreate(test.Ctx(), runtime.Object(invalidRayCluster)) + test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateCreate function due to manipulated head group container SecurityContext") + }) + } func TestValidateUpdate(t *testing.T) { @@ -409,6 +427,15 @@ func TestValidateUpdate(t *testing.T) { ReadOnly: true, }, }, + SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: ptr.To(false), + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + }, + SeccompProfile: &corev1.SeccompProfile{ + Type: "RuntimeDefault", + }, + }, }, }, InitContainers: []corev1.Container{ @@ -485,6 +512,15 @@ func TestValidateUpdate(t *testing.T) { {Name: "RAY_TLS_SERVER_KEY", Value: "/home/ray/workspace/tls/server.key"}, {Name: "RAY_TLS_CA_CERT", Value: "/home/ray/workspace/tls/ca.crt"}, }, + SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: ptr.To(false), + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + }, + SeccompProfile: &corev1.SeccompProfile{ + Type: "RuntimeDefault", + }, + }, }, }, InitContainers: []corev1.Container{ @@ -644,4 +680,22 @@ func TestValidateUpdate(t *testing.T) { _, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster)) test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated env vars in the worker group") }) + + t.Run("Negative: Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the head group container", func(t *testing.T) { + for i := range invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext.AllowPrivilegeEscalation = ptr.To(true) + } + _, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster)) + test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the head group container") + }) + + t.Run("Negative: Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the worker group container", func(t *testing.T) { + for i := range invalidRayCluster.Spec.WorkerGroupSpecs { + for j := range invalidRayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers { + invalidRayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext.AllowPrivilegeEscalation = ptr.To(true) + } + } + _, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster)) + test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the worker group container") + }) } From 20d8abf0ac817c214b77c6e9c2ae8b4133a41ba7 Mon Sep 17 00:00:00 2001 From: ChristianZaccaria Date: Thu, 20 Jun 2024 15:42:03 +0100 Subject: [PATCH 2/2] Add to MutatingWebhook to add SecurityContext to the head and worker containers --- pkg/controllers/raycluster_webhook.go | 10 ++++++++++ pkg/controllers/raycluster_webhook_test.go | 16 ++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/pkg/controllers/raycluster_webhook.go b/pkg/controllers/raycluster_webhook.go index c827a1ea8..989fc242a 100644 --- a/pkg/controllers/raycluster_webhook.go +++ b/pkg/controllers/raycluster_webhook.go @@ -124,6 +124,16 @@ func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) err } } + // Set the security context for the head container and worker containers + for i := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext = securityContext() + } + for i := range rayCluster.Spec.WorkerGroupSpecs { + for j := range rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers { + rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext = securityContext() + } + } + return nil } diff --git a/pkg/controllers/raycluster_webhook_test.go b/pkg/controllers/raycluster_webhook_test.go index 373bc186d..d3cb379f4 100644 --- a/pkg/controllers/raycluster_webhook_test.go +++ b/pkg/controllers/raycluster_webhook_test.go @@ -227,6 +227,22 @@ func TestRayClusterWebhookDefault(t *testing.T) { } }) + t.Run("Expected required SecurityContext for each head group container", func(t *testing.T) { + for _, container := range validRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + test.Expect(container.SecurityContext).To(Equal(securityContext()), + "Expected the required SecurityContext to be present in each head group container") + } + }) + + t.Run("Expected required SecurityContext for each worker group container", func(t *testing.T) { + for _, workerGroup := range validRayCluster.Spec.WorkerGroupSpecs { + for _, container := range workerGroup.Template.Spec.Containers { + test.Expect(container.SecurityContext).To(Equal(securityContext()), + "Expected the required SecurityContext to be present in each worker group container") + } + } + }) + } func TestValidateCreate(t *testing.T) {