diff --git a/pkg/controllers/raycluster_webhook.go b/pkg/controllers/raycluster_webhook.go index dab128115..989fc242a 100644 --- a/pkg/controllers/raycluster_webhook.go +++ b/pkg/controllers/raycluster_webhook.go @@ -18,6 +18,7 @@ package controllers import ( "context" + "reflect" "strconv" rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" @@ -123,6 +124,16 @@ func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) err } } + // Set the security context for the head container and worker containers + for i := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext = securityContext() + } + for i := range rayCluster.Spec.WorkerGroupSpecs { + for j := range rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers { + rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext = securityContext() + } + } + return nil } @@ -133,6 +144,7 @@ func (w *rayClusterWebhook) ValidateCreate(ctx context.Context, obj runtime.Obje var allErrors field.ErrorList allErrors = append(allErrors, validateIngress(rayCluster)...) + allErrors = append(allErrors, validateSecurityContext(rayCluster)...) if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) { allErrors = append(allErrors, validateOAuthProxyContainer(rayCluster)...) @@ -155,6 +167,7 @@ func (w *rayClusterWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj r } allErrors = append(allErrors, validateIngress(rayCluster)...) + allErrors = append(allErrors, validateSecurityContext(rayCluster)...) if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) { allErrors = append(allErrors, validateOAuthProxyContainer(rayCluster)...) @@ -202,6 +215,32 @@ func validateOAuthProxyVolume(rayCluster *rayv1.RayCluster) field.ErrorList { return allErrors } +func validateSecurityContext(rayCluster *rayv1.RayCluster) field.ErrorList { + var allErrors field.ErrorList + + for i := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + if !reflect.DeepEqual(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext, securityContext()) { + allErrors = append(allErrors, field.Invalid( + field.NewPath("spec", "headGroupSpec", "template", "spec", "containers", strconv.Itoa(i), "securityContext"), + rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext, + "SecurityContext is immutable")) + } + } + + for i := range rayCluster.Spec.WorkerGroupSpecs { + for j := range rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers { + if !reflect.DeepEqual(rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext, securityContext()) { + allErrors = append(allErrors, field.Invalid( + field.NewPath("spec", "workerGroupSpecs", strconv.Itoa(i), "template", "spec", "containers", strconv.Itoa(j), "securityContext"), + rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext, + "SecurityContext is immutable")) + } + } + } + + return allErrors +} + func validateIngress(rayCluster *rayv1.RayCluster) field.ErrorList { var allErrors field.ErrorList @@ -268,6 +307,18 @@ func oauthProxyContainer(rayCluster *rayv1.RayCluster) corev1.Container { } } +func securityContext() *corev1.SecurityContext { + return &corev1.SecurityContext{ + AllowPrivilegeEscalation: ptr.To(false), + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + }, + SeccompProfile: &corev1.SeccompProfile{ + Type: "RuntimeDefault", + }, + } +} + func oauthProxyTLSSecretVolume(rayCluster *rayv1.RayCluster) corev1.Volume { return corev1.Volume{ Name: oauthProxyVolumeName, diff --git a/pkg/controllers/raycluster_webhook_test.go b/pkg/controllers/raycluster_webhook_test.go index 44927309d..d3cb379f4 100644 --- a/pkg/controllers/raycluster_webhook_test.go +++ b/pkg/controllers/raycluster_webhook_test.go @@ -26,6 +26,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/utils/ptr" "github.com/project-codeflare/codeflare-operator/pkg/config" ) @@ -226,6 +227,22 @@ func TestRayClusterWebhookDefault(t *testing.T) { } }) + t.Run("Expected required SecurityContext for each head group container", func(t *testing.T) { + for _, container := range validRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + test.Expect(container.SecurityContext).To(Equal(securityContext()), + "Expected the required SecurityContext to be present in each head group container") + } + }) + + t.Run("Expected required SecurityContext for each worker group container", func(t *testing.T) { + for _, workerGroup := range validRayCluster.Spec.WorkerGroupSpecs { + for _, container := range workerGroup.Template.Spec.Containers { + test.Expect(container.SecurityContext).To(Equal(securityContext()), + "Expected the required SecurityContext to be present in each worker group container") + } + } + }) + } func TestValidateCreate(t *testing.T) { @@ -277,6 +294,15 @@ func TestValidateCreate(t *testing.T) { ReadOnly: true, }, }, + SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: ptr.To(false), + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + }, + SeccompProfile: &corev1.SeccompProfile{ + Type: "RuntimeDefault", + }, + }, }, }, Volumes: []corev1.Volume{ @@ -346,6 +372,14 @@ func TestValidateCreate(t *testing.T) { test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateCreate function due to manipulated head group service account name") }) + t.Run("Negative: Expected errors on call to ValidateCreate function due to manipulated head group container SecurityContext", func(t *testing.T) { + for i := range invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext.AllowPrivilegeEscalation = ptr.To(true) + } + _, err = rcWebhook.ValidateCreate(test.Ctx(), runtime.Object(invalidRayCluster)) + test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateCreate function due to manipulated head group container SecurityContext") + }) + } func TestValidateUpdate(t *testing.T) { @@ -409,6 +443,15 @@ func TestValidateUpdate(t *testing.T) { ReadOnly: true, }, }, + SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: ptr.To(false), + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + }, + SeccompProfile: &corev1.SeccompProfile{ + Type: "RuntimeDefault", + }, + }, }, }, InitContainers: []corev1.Container{ @@ -485,6 +528,15 @@ func TestValidateUpdate(t *testing.T) { {Name: "RAY_TLS_SERVER_KEY", Value: "/home/ray/workspace/tls/server.key"}, {Name: "RAY_TLS_CA_CERT", Value: "/home/ray/workspace/tls/ca.crt"}, }, + SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: ptr.To(false), + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + }, + SeccompProfile: &corev1.SeccompProfile{ + Type: "RuntimeDefault", + }, + }, }, }, InitContainers: []corev1.Container{ @@ -644,4 +696,22 @@ func TestValidateUpdate(t *testing.T) { _, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster)) test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated env vars in the worker group") }) + + t.Run("Negative: Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the head group container", func(t *testing.T) { + for i := range invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers { + invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext.AllowPrivilegeEscalation = ptr.To(true) + } + _, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster)) + test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the head group container") + }) + + t.Run("Negative: Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the worker group container", func(t *testing.T) { + for i := range invalidRayCluster.Spec.WorkerGroupSpecs { + for j := range invalidRayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers { + invalidRayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext.AllowPrivilegeEscalation = ptr.To(true) + } + } + _, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster)) + test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the worker group container") + }) }