Skip to content

Commit 3d4a636

Browse files
authored
prefix state (#1509)
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 2220efc commit 3d4a636

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,7 @@ func (p *Plugin) WithName(name string) *Plugin {
174174
}
175175

176176
// Score returns the scoring result for the given list of pods based on context.
177-
func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
177+
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
178178
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
179179
// Pre-score step: hash the prompt and find the longest prefix match.
180180
hashes := hashPrompt(ctx, request, p.config.HashBlockSize, p.config.MaxPrefixBlocksToMatch)
@@ -183,7 +183,8 @@ func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types.
183183
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
184184
}
185185

186-
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().Type), state)
186+
cycleState.Write(plugins.StateKey(p.TypedName().String()), state)
187+
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
187188
loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
188189
// calculate the scores of pods
189190
scores := make(map[types.Pod]float64, len(pods))
@@ -208,7 +209,7 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
208209
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
209210
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
210211

211-
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, PrefixCachePluginType)
212+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
212213
p.pluginState.Delete(request.RequestId) // delete the state explicitly once we are done using it
213214
if err != nil {
214215
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,8 @@ func TestPrefixPlugin(t *testing.T) {
5252
TargetModel: "test-model1",
5353
Prompt: "aaaaaa",
5454
}
55-
scores := plugin.Score(context.Background(), nil, req1, pods)
56-
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, PrefixCachePluginType)
55+
scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
56+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
5757
assert.NoError(t, err)
5858
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
5959
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
@@ -79,8 +79,8 @@ func TestPrefixPlugin(t *testing.T) {
7979
TargetModel: "test-model2",
8080
Prompt: "bbbbbb",
8181
}
82-
scores = plugin.Score(context.Background(), nil, req2, pods)
83-
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, PrefixCachePluginType)
82+
scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
83+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
8484
assert.NoError(t, err)
8585
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
8686
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
@@ -105,8 +105,8 @@ func TestPrefixPlugin(t *testing.T) {
105105
TargetModel: "test-model1",
106106
Prompt: "aaaabbbb",
107107
}
108-
scores = plugin.Score(context.Background(), nil, req3, pods)
109-
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, PrefixCachePluginType)
108+
scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
109+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
110110
assert.NoError(t, err)
111111
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
112112
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
@@ -130,8 +130,8 @@ func TestPrefixPlugin(t *testing.T) {
130130
TargetModel: "test-model-new",
131131
Prompt: "aaaabbbb",
132132
}
133-
scores = plugin.Score(context.Background(), nil, req4, pods)
134-
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, PrefixCachePluginType)
133+
scores = plugin.Score(context.Background(), types.NewCycleState(), req4, pods)
134+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, plugins.StateKey(plugin.TypedName().String()))
135135
assert.NoError(t, err)
136136
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
137137
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
@@ -155,8 +155,8 @@ func TestPrefixPlugin(t *testing.T) {
155155
TargetModel: "test-model1",
156156
Prompt: "aaaabbbbcccc",
157157
}
158-
scores = plugin.Score(context.Background(), nil, req5, pods)
159-
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, PrefixCachePluginType)
158+
scores = plugin.Score(context.Background(), types.NewCycleState(), req5, pods)
159+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, plugins.StateKey(plugin.TypedName().String()))
160160
assert.NoError(t, err)
161161
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
162162
// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
@@ -212,7 +212,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
212212
}
213213

214214
// First cycle: simulate scheduling and insert prefix info into the cache
215-
plugin.Score(context.Background(), nil, req, pods)
215+
plugin.Score(context.Background(), types.NewCycleState(), req, pods)
216216
schedulingResult := &types.SchedulingResult{
217217
PrimaryProfileName: "default",
218218
ProfileResults: map[string]*types.ProfileRunResult{
@@ -222,7 +222,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
222222
plugin.PreRequest(context.Background(), req, schedulingResult, 0)
223223

224224
// Second cycle: validate internal state
225-
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, PrefixCachePluginType)
225+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
226226
assert.NoError(b, err)
227227
expectedHashes := int(math.Min(float64(maxPrefixBlocks), float64(len(req.Prompt)/blockSize)))
228228
assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect")

0 commit comments

Comments
 (0)