diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index eb2ea3b25..45b455ab3 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -404,6 +404,7 @@ func (r *Runner) registerInTreePlugins() { plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory) plugins.Register(picker.MaxScorePickerType, picker.MaxScorePickerFactory) plugins.Register(picker.RandomPickerType, picker.RandomPickerFactory) + plugins.Register(picker.WeightedRandomPickerType, picker.WeightedRandomPickerFactory) plugins.Register(profile.SingleProfileHandlerType, profile.SingleProfileHandlerFactory) plugins.Register(scorer.KvCacheUtilizationScorerType, scorer.KvCacheUtilizationScorerFactory) plugins.Register(scorer.QueueScorerType, scorer.QueueScorerFactory) diff --git a/pkg/epp/config/loader/configloader_test.go b/pkg/epp/config/loader/configloader_test.go index 0701f9642..260fe3655 100644 --- a/pkg/epp/config/loader/configloader_test.go +++ b/pkg/epp/config/loader/configloader_test.go @@ -443,6 +443,7 @@ func registerNeededPlgugins() { plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory) plugins.Register(picker.MaxScorePickerType, picker.MaxScorePickerFactory) plugins.Register(picker.RandomPickerType, picker.RandomPickerFactory) + plugins.Register(picker.WeightedRandomPickerType, picker.WeightedRandomPickerFactory) plugins.Register(profile.SingleProfileHandlerType, profile.SingleProfileHandlerFactory) } diff --git a/pkg/epp/scheduling/framework/plugins/picker/picker_test.go b/pkg/epp/scheduling/framework/plugins/picker/picker_test.go index 2c3acebda..741a49d59 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/picker_test.go +++ b/pkg/epp/scheduling/framework/plugins/picker/picker_test.go @@ -135,3 +135,120 @@ func TestPickMaxScorePicker(t *testing.T) { }) } } + +func TestPickWeightedRandomPicker(t *testing.T) { + const ( + testIterations = 1000 + tolerance = 0.2 // 20% tolerance in [0,1] range + ) + + pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}} + pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}} + pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}} + pod4 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod4"}}} + pod5 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod5"}}} + + // A-Res algorithm uses U^(1/w) transformation which introduces statistical variance + // beyond simple proportional sampling. Generous tolerance is required to prevent + // flaky tests in CI environments, especially for multi-tier weights. + tests := []struct { + name string + input []*types.ScoredPod + maxPods int // maxNumOfEndpoints for this test + }{ + { + name: "High weight dominance test", + input: []*types.ScoredPod{ + {Pod: pod1, Score: 10}, // Lower weight + {Pod: pod2, Score: 90}, // Higher weight (should dominate) + }, + maxPods: 1, + }, + { + name: "Equal weights test - A-Res uniform distribution", + input: []*types.ScoredPod{ + {Pod: pod1, Score: 100}, // Equal weights (higher values for better numerical precision) + {Pod: pod2, Score: 100}, // Equal weights should yield uniform distribution + {Pod: pod3, Score: 100}, // Equal weights in A-Res + }, + maxPods: 1, + }, + { + name: "Zero weight exclusion test - A-Res edge case", + input: []*types.ScoredPod{ + {Pod: pod1, Score: 30}, // Normal weight, should be selected + {Pod: pod2, Score: 0}, // Zero weight, never selected in A-Res + }, + maxPods: 1, + }, + { + name: "Multi-tier weighted test - A-Res complex distribution", + input: []*types.ScoredPod{ + {Pod: pod1, Score: 100}, // Highest weight + {Pod: pod2, Score: 90}, // High weight + {Pod: pod3, Score: 50}, // Medium weight + {Pod: pod4, Score: 30}, // Low weight + {Pod: pod5, Score: 20}, // Lowest weight + }, + maxPods: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + picker := NewWeightedRandomPicker(test.maxPods) + selectionCounts := make(map[string]int) + + // Calculate expected probabilities based on scores + totalScore := 0.0 + for _, pod := range test.input { + totalScore += pod.Score + } + + expectedProbabilities := make(map[string]float64) + for _, pod := range test.input { + podName := pod.GetPod().NamespacedName.Name + if totalScore > 0 { + expectedProbabilities[podName] = pod.Score / totalScore + } else { + expectedProbabilities[podName] = 0.0 + } + } + + // Initialize selection counters for each pod + for _, pod := range test.input { + podName := pod.GetPod().NamespacedName.Name + selectionCounts[podName] = 0 + } + + // Run multiple iterations to gather statistical data + for i := 0; i < testIterations; i++ { + result := picker.Pick(context.Background(), types.NewCycleState(), test.input) + + // Count selections for probability analysis + if len(result.TargetPods) > 0 { + selectedPodName := result.TargetPods[0].GetPod().NamespacedName.Name + selectionCounts[selectedPodName]++ + } + } + + // Verify probability distribution + for podName, expectedProb := range expectedProbabilities { + actualCount := selectionCounts[podName] + actualProb := float64(actualCount) / float64(testIterations) + + toleranceValue := expectedProb * tolerance + lowerBound := expectedProb - toleranceValue + upperBound := expectedProb + toleranceValue + + if actualProb < lowerBound || actualProb > upperBound { + t.Errorf("Pod %s: expected probability %.3f ±%.1f%%, got %.3f (count: %d/%d)", + podName, expectedProb, tolerance*100, actualProb, actualCount, testIterations) + } else { + t.Logf("Pod %s: expected %.3f, got %.3f (count: %d/%d) ✓", + podName, expectedProb, actualProb, actualCount, testIterations) + } + } + }) + } +} diff --git a/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go b/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go new file mode 100644 index 000000000..c12ab72b0 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go @@ -0,0 +1,169 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package picker + +import ( + "context" + "encoding/json" + "fmt" + "math" + "math/rand" + "sort" + "time" + + "sigs.k8s.io/controller-runtime/pkg/log" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + WeightedRandomPickerType = "weighted-random-picker" +) + +// weightedScoredPod represents a scored pod with its A-Res sampling key +type weightedScoredPod struct { + *types.ScoredPod + key float64 +} + +var _ framework.Picker = &WeightedRandomPicker{} + +func WeightedRandomPickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + parameters := pickerParameters{ + MaxNumOfEndpoints: DefaultMaxNumOfEndpoints, + } + if rawParameters != nil { + if err := json.Unmarshal(rawParameters, ¶meters); err != nil { + return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", WeightedRandomPickerType, err) + } + } + + return NewWeightedRandomPicker(parameters.MaxNumOfEndpoints).WithName(name), nil +} + +func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker { + if maxNumOfEndpoints <= 0 { + maxNumOfEndpoints = DefaultMaxNumOfEndpoints + } + + return &WeightedRandomPicker{ + typedName: plugins.TypedName{Type: WeightedRandomPickerType, Name: WeightedRandomPickerType}, + maxNumOfEndpoints: maxNumOfEndpoints, + randomPicker: NewRandomPicker(maxNumOfEndpoints), + } +} + +type WeightedRandomPicker struct { + typedName plugins.TypedName + maxNumOfEndpoints int + randomPicker *RandomPicker // fallback for zero weights +} + +func (p *WeightedRandomPicker) WithName(name string) *WeightedRandomPicker { + p.typedName.Name = name + return p +} + +func (p *WeightedRandomPicker) TypedName() plugins.TypedName { + return p.typedName +} + +// WeightedRandomPicker performs weighted random sampling using A-Res algorithm. +// Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf +// Algorithm: +// - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ) +// - Selects k items with largest keys for mathematically correct weighted sampling +// - More efficient than traditional cumulative probability approach +// +// Key characteristics: +// - Mathematically correct weighted random sampling +// - Single pass algorithm with O(n + k log k) complexity +func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { + log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates using weighted random sampling: %+v", + p.maxNumOfEndpoints, len(scoredPods), scoredPods)) + + // Check if all weights are zero or negative + allZeroWeights := true + for _, scoredPod := range scoredPods { + if scoredPod.Score > 0 { + allZeroWeights = false + break + } + } + + // Delegate to RandomPicker for uniform selection when all weights are zero + if allZeroWeights { + log.FromContext(ctx).V(logutil.DEBUG).Info("All weights are zero, delegating to RandomPicker for uniform selection") + return p.randomPicker.Pick(ctx, cycleState, scoredPods) + } + + randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano())) + + // A-Res algorithm: keyᵢ = Uᵢ^(1/wᵢ) + weightedPods := make([]weightedScoredPod, 0, len(scoredPods)) + + for _, scoredPod := range scoredPods { + weight := float64(scoredPod.Score) + + // Handle zero or negative weights + if weight <= 0 { + // Assign very small key for zero-weight pods (effectively excludes them) + weightedPods = append(weightedPods, weightedScoredPod{ + ScoredPod: scoredPod, + key: 0, + }) + continue + } + + // Generate random number U in (0,1) + u := randomGenerator.Float64() + if u == 0 { + u = 1e-10 // Avoid log(0) + } + + // Calculate key = U^(1/weight) + key := math.Pow(u, 1.0/weight) + + weightedPods = append(weightedPods, weightedScoredPod{ + ScoredPod: scoredPod, + key: key, + }) + } + + // Sort by key in descending order (largest keys first) + sort.Slice(weightedPods, func(i, j int) bool { + return weightedPods[i].key > weightedPods[j].key + }) + + // Select top k pods + selectedCount := min(p.maxNumOfEndpoints, len(weightedPods)) + + scoredPods = make([]*types.ScoredPod, selectedCount) + for i := range selectedCount { + scoredPods[i] = weightedPods[i].ScoredPod + } + + targetPods := make([]types.Pod, len(scoredPods)) + for i, scoredPod := range scoredPods { + targetPods[i] = scoredPod + } + + return &types.ProfileRunResult{TargetPods: targetPods} +}