Skip to content

Commit 2220efc

Browse files
authored
if request id was not supplied in header, generate uuid (#1490)
* If the request ID was not supplied in a header, generate a UUID. Signed-off-by: Nir Rozenbaum <[email protected]>
* Convert the per-request map to a sync.Map in plugin state. Signed-off-by: Nir Rozenbaum <[email protected]>
---------
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 0aa64e5 commit 2220efc

File tree

5 files changed

+88
-23
lines changed

5 files changed

+88
-23
lines changed

pkg/epp/handlers/server.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2727
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
2828
"github.com/go-logr/logr"
29+
"github.com/google/uuid"
2930
"google.golang.org/grpc/codes"
3031
"google.golang.org/grpc/status"
3132
"sigs.k8s.io/controller-runtime/pkg/log"
@@ -186,11 +187,18 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
186187

187188
switch v := req.Request.(type) {
188189
case *extProcPb.ProcessingRequest_RequestHeaders:
189-
if requestID := requtil.ExtractHeaderValue(v, requtil.RequestIdHeaderKey); len(requestID) > 0 {
190-
logger = logger.WithValues(requtil.RequestIdHeaderKey, requestID)
191-
loggerTrace = logger.V(logutil.TRACE)
192-
ctx = log.IntoContext(ctx, logger)
190+
requestID := requtil.ExtractHeaderValue(v, requtil.RequestIdHeaderKey)
191+
// A request ID is required so that plugins that hold internal state and use PluginState can maintain that state per request.
192+
// If the request ID was not supplied as a header, we generate one ourselves.
193+
if len(requestID) == 0 {
194+
requestID = uuid.NewString()
195+
loggerTrace.Info("RequestID header is not found in the request, generated a request id")
196+
reqCtx.Request.Headers[requtil.RequestIdHeaderKey] = requestID // update in headers so director can consume it
193197
}
198+
logger = logger.WithValues(requtil.RequestIdHeaderKey, requestID)
199+
loggerTrace = logger.V(logutil.TRACE)
200+
ctx = log.IntoContext(ctx, logger)
201+
194202
err = s.HandleRequestHeaders(reqCtx, v)
195203
case *extProcPb.ProcessingRequest_RequestBody:
196204
loggerTrace.Info("Incoming body chunk", "EoS", v.RequestBody.EndOfStream)

pkg/epp/plugins/plugin_state.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func NewPluginState(ctx context.Context) *PluginState {
5151
// Note: PluginState uses a sync.Map to back the storage, because it is thread safe.
5252
// It's aimed to optimize for the "write once and read many times" scenarios.
5353
type PluginState struct {
54-
// key: RequestID, value: map[StateKey]StateData
54+
// key: RequestID, value: *sync.Map (key: StateKey, value: StateData)
5555
storage sync.Map
5656
// key: RequestID, value: time.Time
5757
requestToLastAccessTime sync.Map
@@ -66,9 +66,9 @@ func (s *PluginState) Read(requestID string, key StateKey) (StateData, error) {
6666
return nil, ErrNotFound
6767
}
6868

69-
stateData := stateMap.(map[StateKey]StateData)
70-
if value, ok := stateData[key]; ok {
71-
return value, nil
69+
stateData := stateMap.(*sync.Map)
70+
if value, ok := stateData.Load(key); ok {
71+
return value.(StateData), nil
7272
}
7373

7474
return nil, ErrNotFound
@@ -77,15 +77,15 @@ func (s *PluginState) Read(requestID string, key StateKey) (StateData, error) {
7777
// Write stores the given "val" in PluginState with the given "key" in the context of the given "requestID".
7878
func (s *PluginState) Write(requestID string, key StateKey, val StateData) {
7979
s.requestToLastAccessTime.Store(requestID, time.Now())
80-
var stateData map[StateKey]StateData
80+
var stateData *sync.Map
8181
stateMap, ok := s.storage.Load(requestID)
8282
if ok {
83-
stateData = stateMap.(map[StateKey]StateData)
83+
stateData = stateMap.(*sync.Map)
8484
} else {
85-
stateData = map[StateKey]StateData{}
85+
stateData = &sync.Map{}
8686
}
8787

88-
stateData[key] = val
88+
stateData.Store(key, val)
8989

9090
s.storage.Store(requestID, stateData)
9191
}

pkg/epp/server/server_test.go

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,10 @@ const (
4242
)
4343

4444
func TestServer(t *testing.T) {
45-
theHeaderValue := "body"
46-
requestHeader := "x-test"
47-
4845
expectedRequestHeaders := map[string]string{metadata.DestinationEndpointKey: fmt.Sprintf("%s:%d", podAddress, poolPort),
49-
"Content-Length": "42", ":method": "POST", requestHeader: theHeaderValue}
50-
expectedResponseHeaders := map[string]string{"x-went-into-resp-headers": "true", ":method": "POST", requestHeader: theHeaderValue}
51-
expectedSchedulerHeaders := map[string]string{":method": "POST", requestHeader: theHeaderValue}
46+
"Content-Length": "42", ":method": "POST", "x-test": "body", "x-request-id": "test-request-id"}
47+
expectedResponseHeaders := map[string]string{"x-went-into-resp-headers": "true", ":method": "POST", "x-test": "body"}
48+
expectedSchedulerHeaders := map[string]string{":method": "POST", "x-test": "body", "x-request-id": "test-request-id"}
5249

5350
t.Run("server", func(t *testing.T) {
5451
model := testutil.MakeInferenceObjective("v1").
@@ -66,9 +63,10 @@ func TestServer(t *testing.T) {
6663

6764
// Send request headers - no response expected
6865
headers := utils.BuildEnvoyGRPCHeaders(map[string]string{
69-
requestHeader: theHeaderValue,
66+
"x-test": "body",
7067
":method": "POST",
7168
metadata.FlowFairnessIDKey: "a-very-interesting-fairness-id",
69+
"x-request-id": "test-request-id",
7270
}, true)
7371
request := &pb.ProcessingRequest{
7472
Request: &pb.ProcessingRequest_RequestHeaders{
@@ -130,9 +128,6 @@ func TestServer(t *testing.T) {
130128
}
131129

132130
// Check headers passed to the scheduler
133-
if len(director.requestHeaders) != 2 {
134-
t.Errorf("Incorrect number of request headers %d instead of 2", len(director.requestHeaders))
135-
}
136131
for expectedKey, expectedValue := range expectedSchedulerHeaders {
137132
got, ok := director.requestHeaders[expectedKey]
138133
if !ok {
@@ -143,7 +138,7 @@ func TestServer(t *testing.T) {
143138
}
144139

145140
// Send response headers
146-
headers = utils.BuildEnvoyGRPCHeaders(map[string]string{requestHeader: theHeaderValue, ":method": "POST"}, false)
141+
headers = utils.BuildEnvoyGRPCHeaders(map[string]string{"x-test": "body", ":method": "POST"}, false)
147142
request = &pb.ProcessingRequest{
148143
Request: &pb.ProcessingRequest_ResponseHeaders{
149144
ResponseHeaders: headers,

test/integration/epp/hermetic_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ import (
7676
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer"
7777
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
7878
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
79+
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
7980
epptestutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
8081
integrationutils "sigs.k8s.io/gateway-api-inference-extension/test/integration"
8182
)
@@ -187,6 +188,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
187188
RawValue: []byte("mom"),
188189
},
189190
},
191+
&configPb.HeaderValueOption{
192+
Header: &configPb.HeaderValue{
193+
Key: requtil.RequestIdHeaderKey,
194+
RawValue: []byte("test-request-id"),
195+
},
196+
},
190197
),
191198
},
192199
{
@@ -250,6 +257,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
250257
RawValue: []byte("mom"),
251258
},
252259
},
260+
&configPb.HeaderValueOption{
261+
Header: &configPb.HeaderValue{
262+
Key: requtil.RequestIdHeaderKey,
263+
RawValue: []byte("test-request-id"),
264+
},
265+
},
253266
),
254267
},
255268
{
@@ -279,6 +292,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
279292
RawValue: []byte("mom"),
280293
},
281294
},
295+
&configPb.HeaderValueOption{
296+
Header: &configPb.HeaderValue{
297+
Key: requtil.RequestIdHeaderKey,
298+
RawValue: []byte("test-request-id"),
299+
},
300+
},
282301
),
283302
},
284303
{
@@ -308,6 +327,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
308327
RawValue: []byte("mom"),
309328
},
310329
},
330+
&configPb.HeaderValueOption{
331+
Header: &configPb.HeaderValue{
332+
Key: requtil.RequestIdHeaderKey,
333+
RawValue: []byte("test-request-id"),
334+
},
335+
},
311336
),
312337
},
313338
{
@@ -330,6 +355,10 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
330355
Key: metadata.ModelNameRewriteKey,
331356
Value: modelSheddableTarget,
332357
},
358+
{
359+
Key: requtil.RequestIdHeaderKey,
360+
Value: "test-request-id",
361+
},
333362
},
334363
},
335364
},
@@ -368,6 +397,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
368397
RawValue: []byte("mom"),
369398
},
370399
},
400+
&configPb.HeaderValueOption{
401+
Header: &configPb.HeaderValue{
402+
Key: requtil.RequestIdHeaderKey,
403+
RawValue: []byte("test-request-id"),
404+
},
405+
},
371406
),
372407
},
373408
{
@@ -394,6 +429,10 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
394429
Key: metadata.ModelNameRewriteKey,
395430
Value: modelDirect,
396431
},
432+
{
433+
Key: requtil.RequestIdHeaderKey,
434+
Value: "test-request-id",
435+
},
397436
},
398437
},
399438
},
@@ -432,6 +471,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
432471
RawValue: []byte("mom"),
433472
},
434473
},
474+
&configPb.HeaderValueOption{
475+
Header: &configPb.HeaderValue{
476+
Key: requtil.RequestIdHeaderKey,
477+
RawValue: []byte("test-request-id"),
478+
},
479+
},
435480
),
436481
},
437482
// Response flow tests
@@ -778,6 +823,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
778823
RawValue: []byte("mom"),
779824
},
780825
},
826+
&configPb.HeaderValueOption{
827+
Header: &configPb.HeaderValue{
828+
Key: requtil.RequestIdHeaderKey,
829+
RawValue: []byte("test-request-id"),
830+
},
831+
},
781832
),
782833
},
783834
{
@@ -811,6 +862,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) {
811862
RawValue: []byte("mom"),
812863
},
813864
},
865+
&configPb.HeaderValueOption{
866+
Header: &configPb.HeaderValue{
867+
Key: requtil.RequestIdHeaderKey,
868+
RawValue: []byte("test-request-id"),
869+
},
870+
},
814871
),
815872
},
816873
{

test/integration/util.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131

3232
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
3333
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
34+
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
3435
)
3536

3637
const (
@@ -130,6 +131,10 @@ func GenerateStreamedRequestSet(logger logr.Logger, prompt, model, targetModel s
130131
Key: metadata.ModelNameRewriteKey,
131132
Value: targetModel,
132133
},
134+
{
135+
Key: requtil.RequestIdHeaderKey,
136+
Value: "test-request-id",
137+
},
133138
},
134139
},
135140
},

0 commit comments

Comments
 (0)