Skip to content
18 changes: 15 additions & 3 deletions ray-operator/controllers/ray/utils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ func SafeUint64ToInt64(n uint64) int64 {
return int64(n)
}

// SafeInt64ToInt32 converts n from int64 to int32 without wrapping around.
// Values above math.MaxInt32 are returned as math.MaxInt32, values below
// math.MinInt32 are returned as math.MinInt32, and any value already inside
// that range is converted unchanged.
func SafeInt64ToInt32(n int64) int32 {
// Saturate at the upper bound instead of letting the conversion truncate.
if n > math.MaxInt32 {
return math.MaxInt32
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may need a minimum check to avoid lint error. Also, could you follow the naming convention for the function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback! I'll add the minimum check and see if it resolves the lint error, and rename the function to SafeInt64ToInt32 as well.

// Saturate at the lower bound for the same reason.
if n < math.MinInt32 {
return math.MinInt32
}
// n is now known to fit in an int32, so the conversion is lossless.
return int32(n)
}

// GetNamespace return namespace
func GetNamespace(metaData metav1.ObjectMeta) string {
if metaData.Namespace == "" {
Expand Down Expand Up @@ -393,15 +405,15 @@ func CalculateMinReplicas(cluster *rayv1.RayCluster) int32 {

// CalculateMaxReplicas calculates max worker replicas at the cluster level.
// It walks cluster.Spec.WorkerGroupSpecs, skips every group whose Suspend
// pointer is set and true, and adds MaxReplicas * NumOfHosts for each
// remaining group. In the int64 form of the code the running total is kept in
// an int64 and passed through SafeInt64ToInt32 before being returned.
//
// NOTE(review): this span appears to be a rendered diff with the +/- markers
// stripped, so each int32 statement is immediately followed by its int64
// replacement — confirm which lines are live before compiling.
// NOTE(review): *nodeGroup.MaxReplicas is dereferenced without a nil check —
// confirm that MaxReplicas is always populated by the time this runs.
func CalculateMaxReplicas(cluster *rayv1.RayCluster) int32 {
count := int32(0)
count := int64(0)
for _, nodeGroup := range cluster.Spec.WorkerGroupSpecs {
// Suspended worker groups do not contribute to the total.
if nodeGroup.Suspend != nil && *nodeGroup.Suspend {
continue
}
count += (*nodeGroup.MaxReplicas * nodeGroup.NumOfHosts)
// Widen both operands before multiplying so the product is computed in int64.
count += int64(*nodeGroup.MaxReplicas) * int64(nodeGroup.NumOfHosts)
}

return count
return SafeInt64ToInt32(count)
}

// CalculateReadyReplicas calculates ready worker replicas at the cluster level
Expand Down
94 changes: 94 additions & 0 deletions ray-operator/controllers/ray/utils/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,100 @@ func TestCalculateDesiredReplicas(t *testing.T) {
}
}

// TestCalculateMaxReplicasOverflow checks the value CalculateMaxReplicas
// reports when MaxReplicas * NumOfHosts, or the sum across worker groups,
// would exceed what an int32 can hold. Every overflowing case is expected to
// come back as the int32 maximum; the remaining cases pin the ordinary,
// non-overflowing arithmetic.
func TestCalculateMaxReplicasOverflow(t *testing.T) {
	// maxInt32 is the largest int32 value; it is both the default MaxReplicas
	// used in several cases and the cap expected in the result.
	const maxInt32 int32 = 2147483647

	type overflowCase struct {
		name     string
		specs    []rayv1.WorkerGroupSpec
		expected int32
	}

	cases := []overflowCase{
		{
			name:     "Bug reproduction: issue report with replicas=1, minReplicas=3, numOfHosts=4",
			expected: maxInt32, // Was -4 before the fix; must now be capped.
			specs: []rayv1.WorkerGroupSpec{
				{
					GroupName:   "workergroup",
					Replicas:    ptr.To[int32](1),
					MinReplicas: ptr.To[int32](3),
					MaxReplicas: ptr.To[int32](maxInt32),
					NumOfHosts:  4,
				},
			},
		},
		{
			name:     "Single group overflow with default maxReplicas and numOfHosts=4",
			expected: maxInt32,
			specs: []rayv1.WorkerGroupSpec{
				{
					MinReplicas: ptr.To[int32](3),
					MaxReplicas: ptr.To[int32](maxInt32),
					NumOfHosts:  4,
				},
			},
		},
		{
			name:     "Single group overflow with large values",
			expected: maxInt32,
			specs: []rayv1.WorkerGroupSpec{
				{
					MinReplicas: ptr.To[int32](1),
					MaxReplicas: ptr.To[int32](maxInt32),
					NumOfHosts:  1000,
				},
			},
		},
		{
			name:     "Multiple groups causing overflow when summed",
			expected: maxInt32, // 3B + 1B exceeds the int32 maximum.
			specs: []rayv1.WorkerGroupSpec{
				{
					MinReplicas: ptr.To[int32](1),
					MaxReplicas: ptr.To[int32](1500000000),
					NumOfHosts:  2,
				},
				{
					MinReplicas: ptr.To[int32](1),
					MaxReplicas: ptr.To[int32](1000000000),
					NumOfHosts:  1,
				},
			},
		},
		{
			name:     "No overflow with reasonable values",
			expected: 400, // 100 * 4
			specs: []rayv1.WorkerGroupSpec{
				{
					MinReplicas: ptr.To[int32](2),
					MaxReplicas: ptr.To[int32](100),
					NumOfHosts:  4,
				},
			},
		},
		{
			name:     "Edge case: exactly at max int32",
			expected: maxInt32, // Sits exactly on the limit, so nothing is clipped.
			specs: []rayv1.WorkerGroupSpec{
				{
					MinReplicas: ptr.To[int32](1),
					MaxReplicas: ptr.To[int32](maxInt32),
					NumOfHosts:  1,
				},
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			// Only the worker group specs matter to the function under test.
			cluster := &rayv1.RayCluster{}
			cluster.Spec.WorkerGroupSpecs = tt.specs

			assert.Equal(t, tt.expected, CalculateMaxReplicas(cluster))
		})
	}
}

func TestUnmarshalRuntimeEnv(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading