Skip to content

Commit bd0aff8

Browse files
Add to ValidatingWebhook to check for SecurityContext in head and worker containers.
1 parent cbbee9d commit bd0aff8

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

pkg/controllers/raycluster_webhook.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package controllers
1818

1919
import (
2020
"context"
21+
"reflect"
2122
"strconv"
2223

2324
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
@@ -133,6 +134,7 @@ func (w *rayClusterWebhook) ValidateCreate(ctx context.Context, obj runtime.Obje
133134
var allErrors field.ErrorList
134135

135136
allErrors = append(allErrors, validateIngress(rayCluster)...)
137+
allErrors = append(allErrors, validateSecurityContext(rayCluster)...)
136138

137139
if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) {
138140
allErrors = append(allErrors, validateOAuthProxyContainer(rayCluster)...)
@@ -155,6 +157,7 @@ func (w *rayClusterWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj r
155157
}
156158

157159
allErrors = append(allErrors, validateIngress(rayCluster)...)
160+
allErrors = append(allErrors, validateSecurityContext(rayCluster)...)
158161

159162
if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) {
160163
allErrors = append(allErrors, validateOAuthProxyContainer(rayCluster)...)
@@ -202,6 +205,32 @@ func validateOAuthProxyVolume(rayCluster *rayv1.RayCluster) field.ErrorList {
202205
return allErrors
203206
}
204207

208+
func validateSecurityContext(rayCluster *rayv1.RayCluster) field.ErrorList {
209+
var allErrors field.ErrorList
210+
211+
for i := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
212+
if !reflect.DeepEqual(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext, securityContext()) {
213+
allErrors = append(allErrors, field.Invalid(
214+
field.NewPath("spec", "headGroupSpec", "template", "spec", "containers", strconv.Itoa(i), "securityContext"),
215+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext,
216+
"SecurityContext is immutable"))
217+
}
218+
}
219+
220+
for i := range rayCluster.Spec.WorkerGroupSpecs {
221+
for j := range rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers {
222+
if !reflect.DeepEqual(rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext, securityContext()) {
223+
allErrors = append(allErrors, field.Invalid(
224+
field.NewPath("spec", "workerGroupSpecs", strconv.Itoa(i), "template", "spec", "containers", strconv.Itoa(j), "securityContext"),
225+
rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext,
226+
"SecurityContext is immutable"))
227+
}
228+
}
229+
}
230+
231+
return allErrors
232+
}
233+
205234
func validateIngress(rayCluster *rayv1.RayCluster) field.ErrorList {
206235
var allErrors field.ErrorList
207236

@@ -268,6 +297,18 @@ func oauthProxyContainer(rayCluster *rayv1.RayCluster) corev1.Container {
268297
}
269298
}
270299

300+
func securityContext() *corev1.SecurityContext {
301+
return &corev1.SecurityContext{
302+
AllowPrivilegeEscalation: ptr.To(false),
303+
Capabilities: &corev1.Capabilities{
304+
Drop: []corev1.Capability{"ALL"},
305+
},
306+
SeccompProfile: &corev1.SeccompProfile{
307+
Type: "RuntimeDefault",
308+
},
309+
}
310+
}
311+
271312
func oauthProxyTLSSecretVolume(rayCluster *rayv1.RayCluster) corev1.Volume {
272313
return corev1.Volume{
273314
Name: oauthProxyVolumeName,

pkg/controllers/raycluster_webhook_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
corev1 "k8s.io/api/core/v1"
2727
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2828
"k8s.io/apimachinery/pkg/runtime"
29+
"k8s.io/utils/ptr"
2930

3031
"github.com/project-codeflare/codeflare-operator/pkg/config"
3132
)
@@ -277,6 +278,15 @@ func TestValidateCreate(t *testing.T) {
277278
ReadOnly: true,
278279
},
279280
},
281+
SecurityContext: &corev1.SecurityContext{
282+
AllowPrivilegeEscalation: ptr.To(false),
283+
Capabilities: &corev1.Capabilities{
284+
Drop: []corev1.Capability{"ALL"},
285+
},
286+
SeccompProfile: &corev1.SeccompProfile{
287+
Type: "RuntimeDefault",
288+
},
289+
},
280290
},
281291
},
282292
Volumes: []corev1.Volume{
@@ -346,6 +356,14 @@ func TestValidateCreate(t *testing.T) {
346356
test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateCreate function due to manipulated head group service account name")
347357
})
348358

359+
t.Run("Negative: Expected errors on call to ValidateCreate function due to manipulated head group container SecurityContext", func(t *testing.T) {
360+
for i := range invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
361+
invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext.AllowPrivilegeEscalation = ptr.To(true)
362+
}
363+
_, err = rcWebhook.ValidateCreate(test.Ctx(), runtime.Object(invalidRayCluster))
364+
test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateCreate function due to manipulated head group container SecurityContext")
365+
})
366+
349367
}
350368

351369
func TestValidateUpdate(t *testing.T) {
@@ -409,6 +427,15 @@ func TestValidateUpdate(t *testing.T) {
409427
ReadOnly: true,
410428
},
411429
},
430+
SecurityContext: &corev1.SecurityContext{
431+
AllowPrivilegeEscalation: ptr.To(false),
432+
Capabilities: &corev1.Capabilities{
433+
Drop: []corev1.Capability{"ALL"},
434+
},
435+
SeccompProfile: &corev1.SeccompProfile{
436+
Type: "RuntimeDefault",
437+
},
438+
},
412439
},
413440
},
414441
InitContainers: []corev1.Container{
@@ -485,6 +512,15 @@ func TestValidateUpdate(t *testing.T) {
485512
{Name: "RAY_TLS_SERVER_KEY", Value: "/home/ray/workspace/tls/server.key"},
486513
{Name: "RAY_TLS_CA_CERT", Value: "/home/ray/workspace/tls/ca.crt"},
487514
},
515+
SecurityContext: &corev1.SecurityContext{
516+
AllowPrivilegeEscalation: ptr.To(false),
517+
Capabilities: &corev1.Capabilities{
518+
Drop: []corev1.Capability{"ALL"},
519+
},
520+
SeccompProfile: &corev1.SeccompProfile{
521+
Type: "RuntimeDefault",
522+
},
523+
},
488524
},
489525
},
490526
InitContainers: []corev1.Container{
@@ -644,4 +680,22 @@ func TestValidateUpdate(t *testing.T) {
644680
_, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster))
645681
test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated env vars in the worker group")
646682
})
683+
684+
t.Run("Negative: Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the head group container", func(t *testing.T) {
685+
for i := range invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
686+
invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext.AllowPrivilegeEscalation = ptr.To(true)
687+
}
688+
_, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster))
689+
test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the head group container")
690+
})
691+
692+
t.Run("Negative: Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the worker group container", func(t *testing.T) {
693+
for i := range invalidRayCluster.Spec.WorkerGroupSpecs {
694+
for j := range invalidRayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers {
695+
invalidRayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext.AllowPrivilegeEscalation = ptr.To(true)
696+
}
697+
}
698+
_, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster))
699+
test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the worker group container")
700+
})
647701
}

0 commit comments

Comments
 (0)