Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ func (r *Runner) registerInTreePlugins() {
plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory)
plugins.Register(picker.MaxScorePickerType, picker.MaxScorePickerFactory)
plugins.Register(picker.RandomPickerType, picker.RandomPickerFactory)
plugins.Register(picker.WeightedRandomPickerType, picker.WeightedRandomPickerFactory)
plugins.Register(profile.SingleProfileHandlerType, profile.SingleProfileHandlerFactory)
plugins.Register(scorer.KvCacheUtilizationScorerType, scorer.KvCacheUtilizationScorerFactory)
plugins.Register(scorer.QueueScorerType, scorer.QueueScorerFactory)
Expand Down
1 change: 1 addition & 0 deletions pkg/epp/config/loader/configloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ func registerNeededPlgugins() {
plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory)
plugins.Register(picker.MaxScorePickerType, picker.MaxScorePickerFactory)
plugins.Register(picker.RandomPickerType, picker.RandomPickerFactory)
plugins.Register(picker.WeightedRandomPickerType, picker.WeightedRandomPickerFactory)
plugins.Register(profile.SingleProfileHandlerType, profile.SingleProfileHandlerFactory)
}

Expand Down
117 changes: 117 additions & 0 deletions pkg/epp/scheduling/framework/plugins/picker/picker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,120 @@ func TestPickMaxScorePicker(t *testing.T) {
})
}
}

func TestPickWeightedRandomPicker(t *testing.T) {
const (
testIterations = 1000
tolerance = 0.2 // 20% tolerance in [0,1] range
)

pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}
pod4 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod4"}}}
pod5 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod5"}}}

// A-Res algorithm uses U^(1/w) transformation which introduces statistical variance
// beyond simple proportional sampling. Generous tolerance is required to prevent
// flaky tests in CI environments, especially for multi-tier weights.
tests := []struct {
name string
input []*types.ScoredPod
maxPods int // maxNumOfEndpoints for this test
}{
{
name: "High weight dominance test",
input: []*types.ScoredPod{
{Pod: pod1, Score: 10}, // Lower weight
{Pod: pod2, Score: 90}, // Higher weight (should dominate)
},
maxPods: 1,
},
{
name: "Equal weights test - A-Res uniform distribution",
input: []*types.ScoredPod{
{Pod: pod1, Score: 100}, // Equal weights (higher values for better numerical precision)
{Pod: pod2, Score: 100}, // Equal weights should yield uniform distribution
{Pod: pod3, Score: 100}, // Equal weights in A-Res
},
maxPods: 1,
},
{
name: "Zero weight exclusion test - A-Res edge case",
input: []*types.ScoredPod{
{Pod: pod1, Score: 30}, // Normal weight, should be selected
{Pod: pod2, Score: 0}, // Zero weight, never selected in A-Res
},
maxPods: 1,
},
{
name: "Multi-tier weighted test - A-Res complex distribution",
input: []*types.ScoredPod{
{Pod: pod1, Score: 100}, // Highest weight
{Pod: pod2, Score: 90}, // High weight
{Pod: pod3, Score: 50}, // Medium weight
{Pod: pod4, Score: 30}, // Low weight
{Pod: pod5, Score: 20}, // Lowest weight
},
maxPods: 1,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
picker := NewWeightedRandomPicker(test.maxPods)
selectionCounts := make(map[string]int)

// Calculate expected probabilities based on scores
totalScore := 0.0
for _, pod := range test.input {
totalScore += pod.Score
}

expectedProbabilities := make(map[string]float64)
for _, pod := range test.input {
podName := pod.GetPod().NamespacedName.Name
if totalScore > 0 {
expectedProbabilities[podName] = pod.Score / totalScore
} else {
expectedProbabilities[podName] = 0.0
}
}

// Initialize selection counters for each pod
for _, pod := range test.input {
podName := pod.GetPod().NamespacedName.Name
selectionCounts[podName] = 0
}

// Run multiple iterations to gather statistical data
for i := 0; i < testIterations; i++ {
result := picker.Pick(context.Background(), types.NewCycleState(), test.input)

// Count selections for probability analysis
if len(result.TargetPods) > 0 {
selectedPodName := result.TargetPods[0].GetPod().NamespacedName.Name
selectionCounts[selectedPodName]++
}
}

// Verify probability distribution
for podName, expectedProb := range expectedProbabilities {
actualCount := selectionCounts[podName]
actualProb := float64(actualCount) / float64(testIterations)

toleranceValue := expectedProb * tolerance
lowerBound := expectedProb - toleranceValue
upperBound := expectedProb + toleranceValue

if actualProb < lowerBound || actualProb > upperBound {
t.Errorf("Pod %s: expected probability %.3f ±%.1f%%, got %.3f (count: %d/%d)",
podName, expectedProb, tolerance*100, actualProb, actualCount, testIterations)
} else {
t.Logf("Pod %s: expected %.3f, got %.3f (count: %d/%d) ✓",
podName, expectedProb, actualProb, actualCount, testIterations)
}
}
})
}
}
169 changes: 169 additions & 0 deletions pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package picker

import (
"context"
"encoding/json"
"fmt"
"math"
"math/rand"
"sort"
"time"

"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

const (
WeightedRandomPickerType = "weighted-random-picker"
)

// weightedScoredPod represents a scored pod with its A-Res sampling key
type weightedScoredPod struct {
*types.ScoredPod
key float64
}

var _ framework.Picker = &WeightedRandomPicker{}

func WeightedRandomPickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
parameters := pickerParameters{
MaxNumOfEndpoints: DefaultMaxNumOfEndpoints,
}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", WeightedRandomPickerType, err)
}
}

return NewWeightedRandomPicker(parameters.MaxNumOfEndpoints).WithName(name), nil
}

func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker {
if maxNumOfEndpoints <= 0 {
maxNumOfEndpoints = DefaultMaxNumOfEndpoints
}

return &WeightedRandomPicker{
typedName: plugins.TypedName{Type: WeightedRandomPickerType, Name: WeightedRandomPickerType},
maxNumOfEndpoints: maxNumOfEndpoints,
randomPicker: NewRandomPicker(maxNumOfEndpoints),
}
}

type WeightedRandomPicker struct {
typedName plugins.TypedName
maxNumOfEndpoints int
randomPicker *RandomPicker // fallback for zero weights
}

func (p *WeightedRandomPicker) WithName(name string) *WeightedRandomPicker {
p.typedName.Name = name
return p
}

func (p *WeightedRandomPicker) TypedName() plugins.TypedName {
return p.typedName
}

// WeightedRandomPicker performs weighted random sampling using A-Res algorithm.
// Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf
// Algorithm:
// - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ)
// - Selects k items with largest keys for mathematically correct weighted sampling
// - More efficient than traditional cumulative probability approach
//
// Key characteristics:
// - Mathematically correct weighted random sampling
// - Single pass algorithm with O(n + k log k) complexity
func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates using weighted random sampling: %+v",
p.maxNumOfEndpoints, len(scoredPods), scoredPods))

// Check if all weights are zero or negative
allZeroWeights := true
for _, scoredPod := range scoredPods {
if scoredPod.Score > 0 {
allZeroWeights = false
break
}
}

// Delegate to RandomPicker for uniform selection when all weights are zero
if allZeroWeights {
log.FromContext(ctx).V(logutil.DEBUG).Info("All weights are zero, delegating to RandomPicker for uniform selection")
return p.randomPicker.Pick(ctx, cycleState, scoredPods)
}

randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))

// A-Res algorithm: keyᵢ = Uᵢ^(1/wᵢ)
weightedPods := make([]weightedScoredPod, 0, len(scoredPods))

for _, scoredPod := range scoredPods {
weight := float64(scoredPod.Score)

// Handle zero or negative weights
if weight <= 0 {
// Assign very small key for zero-weight pods (effectively excludes them)
weightedPods = append(weightedPods, weightedScoredPod{
ScoredPod: scoredPod,
key: 0,
})
continue
}

// Generate random number U in (0,1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be updated to// Generate random number U in [0.0, 1.0) here?

It will be clearer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No 0 is always excluded so the existing comment is correct.

	u := randomGenerator.Float64()
		if u == 0 {
			u = 1e-10 // Avoid log(0)   #<-- this will be like 0.0000000001
		}

u := randomGenerator.Float64()
if u == 0 {
u = 1e-10 // Avoid log(0)
}

// Calculate key = U^(1/weight)
key := math.Pow(u, 1.0/weight)

weightedPods = append(weightedPods, weightedScoredPod{
ScoredPod: scoredPod,
key: key,
})
}

// Sort by key in descending order (largest keys first)
sort.Slice(weightedPods, func(i, j int) bool {
return weightedPods[i].key > weightedPods[j].key
})

// Select top k pods
selectedCount := min(p.maxNumOfEndpoints, len(weightedPods))

scoredPods = make([]*types.ScoredPod, selectedCount)
for i := range selectedCount {
scoredPods[i] = weightedPods[i].ScoredPod
}

targetPods := make([]types.Pod, len(scoredPods))
for i, scoredPod := range scoredPods {
targetPods[i] = scoredPod
}

return &types.ProfileRunResult{TargetPods: targetPods}
}