Skip to content

Commit b135da2

Browse files
MatusKyselfjl
andauthored
rpc: add method name length limit (#31711)
This change adds a limit for RPC method names to prevent potential abuse where large method names could lead to large response sizes. The limit is enforced in: - handleCall for regular RPC method calls - handleSubscribe for subscription method calls Added tests in websocket_test.go to verify the length limit functionality for both regular method calls and subscriptions. --------- Co-authored-by: Felix Lange <[email protected]>
1 parent bca0646 commit b135da2

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

rpc/handler.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,10 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
501501
if msg.isUnsubscribe() {
502502
callb = h.unsubscribeCb
503503
} else {
504+
// Check method name length
505+
if len(msg.Method) > maxMethodNameLength {
506+
return msg.errorResponse(&invalidRequestError{fmt.Sprintf("method name too long: %d > %d", len(msg.Method), maxMethodNameLength)})
507+
}
504508
callb = h.reg.callback(msg.Method)
505509
}
506510
if callb == nil {
@@ -536,6 +540,11 @@ func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMes
536540
return msg.errorResponse(ErrNotificationsUnsupported)
537541
}
538542

543+
// Check method name length
544+
if len(msg.Method) > maxMethodNameLength {
545+
return msg.errorResponse(&invalidRequestError{fmt.Sprintf("subscription name too long: %d > %d", len(msg.Method), maxMethodNameLength)})
546+
}
547+
539548
// Subscription method name is first argument.
540549
name, err := parseSubscriptionName(msg.Params)
541550
if err != nil {

rpc/json.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ const (
3535
subscribeMethodSuffix = "_subscribe"
3636
unsubscribeMethodSuffix = "_unsubscribe"
3737
notificationMethodSuffix = "_subscription"
38+
maxMethodNameLength = 2048
3839

3940
defaultWriteTimeout = 10 * time.Second // used if context has no deadline
4041
)

rpc/websocket_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,77 @@ func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-
391391
}
392392
}
393393
}
394+
395+
func TestWebsocketMethodNameLengthLimit(t *testing.T) {
396+
t.Parallel()
397+
398+
var (
399+
srv = newTestServer()
400+
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
401+
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
402+
)
403+
defer srv.Stop()
404+
defer httpsrv.Close()
405+
406+
client, err := DialWebsocket(context.Background(), wsURL, "")
407+
if err != nil {
408+
t.Fatalf("can't dial: %v", err)
409+
}
410+
defer client.Close()
411+
412+
// Test cases
413+
tests := []struct {
414+
name string
415+
method string
416+
params []interface{}
417+
expectedError string
418+
isSubscription bool
419+
}{
420+
{
421+
name: "valid method name",
422+
method: "test_echo",
423+
params: []interface{}{"test", 1},
424+
expectedError: "",
425+
isSubscription: false,
426+
},
427+
{
428+
name: "method name too long",
429+
method: "test_" + string(make([]byte, maxMethodNameLength+1)),
430+
params: []interface{}{"test", 1},
431+
expectedError: "method name too long",
432+
isSubscription: false,
433+
},
434+
{
435+
name: "valid subscription",
436+
method: "nftest_subscribe",
437+
params: []interface{}{"someSubscription", 1, 2},
438+
expectedError: "",
439+
isSubscription: true,
440+
},
441+
{
442+
name: "subscription name too long",
443+
method: string(make([]byte, maxMethodNameLength+1)) + "_subscribe",
444+
params: []interface{}{"newHeads"},
445+
expectedError: "subscription name too long",
446+
isSubscription: true,
447+
},
448+
}
449+
450+
for _, tt := range tests {
451+
t.Run(tt.name, func(t *testing.T) {
452+
var result interface{}
453+
err := client.Call(&result, tt.method, tt.params...)
454+
if tt.expectedError == "" {
455+
if err != nil {
456+
t.Errorf("unexpected error: %v", err)
457+
}
458+
} else {
459+
if err == nil {
460+
t.Error("expected error, got nil")
461+
} else if !strings.Contains(err.Error(), tt.expectedError) {
462+
t.Errorf("expected error containing %q, got %q", tt.expectedError, err.Error())
463+
}
464+
}
465+
})
466+
}
467+
}

0 commit comments

Comments
 (0)