From 1735b6c4b24cbe7677a4acc1d788aabed519334a Mon Sep 17 00:00:00 2001 From: David Farr Date: Tue, 28 Jan 2025 16:03:07 -0800 Subject: [PATCH] Send to exact match only when notify in poller (#537) --- internal/aio/plugin.go | 3 +++ internal/app/plugins/poll/poll.go | 6 ++++++ internal/app/plugins/poll/poll_test.go | 16 ++++++++++++++++ internal/app/subsystems/aio/sender/sender.go | 9 +++++---- 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/internal/aio/plugin.go b/internal/aio/plugin.go index 1389cee9..9e998a28 100644 --- a/internal/aio/plugin.go +++ b/internal/aio/plugin.go @@ -1,6 +1,9 @@ package aio +import "github.com/resonatehq/resonate/pkg/message" + type Message struct { + Type message.Type Data []byte Body []byte Done func(bool, error) diff --git a/internal/app/plugins/poll/poll.go b/internal/app/plugins/poll/poll.go index 078648ac..aa7601db 100644 --- a/internal/app/plugins/poll/poll.go +++ b/internal/app/plugins/poll/poll.go @@ -17,6 +17,7 @@ import ( "github.com/resonatehq/resonate/internal/kernel/t_aio" "github.com/resonatehq/resonate/internal/metrics" "github.com/resonatehq/resonate/internal/util" + "github.com/resonatehq/resonate/pkg/message" ) type Config struct { @@ -305,6 +306,11 @@ func (w *PollWorker) Process(mesg *aio.Message) { return } + if mesg.Type == message.Notify && conn.id != data.Id { + mesg.Done(false, fmt.Errorf("no connection found for group %s and id %s", data.Group, data.Id)) + return + } + // send message to connection select { case conn.ch <- mesg.Body: diff --git a/internal/app/plugins/poll/poll_test.go b/internal/app/plugins/poll/poll_test.go index b5bc9878..9754152b 100644 --- a/internal/app/plugins/poll/poll_test.go +++ b/internal/app/plugins/poll/poll_test.go @@ -13,6 +13,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/resonatehq/resonate/internal/aio" "github.com/resonatehq/resonate/internal/metrics" + "github.com/resonatehq/resonate/pkg/message" "github.com/stretchr/testify/assert" ) @@ -190,6 +191,21 @@ func TestPollPlugin(t *testing.T) { {"foo", []string{"a", "b", "c"}, "data: ok3"}, }, }, + { + name: "NotifyMustBeSameGroupAndId", + mc: 5, + connections: []*Conn{ + {"foo", "a"}, + }, + messages: []*Mesg{ + {true, &aio.Message{Type: message.Notify, Data: []byte(`{"group":"foo","id":"a"}`), Body: []byte("ok1")}}, + {false, &aio.Message{Type: message.Notify, Data: []byte(`{"group":"foo","id":"b"}`), Body: []byte("ok2")}}, + {false, &aio.Message{Type: message.Notify, Data: []byte(`{"group":"foo","id":"c"}`), Body: []byte("ok3")}}, + }, + expected: []*Resp{ + {"foo", []string{"a"}, "data: ok1"}, + }, + }, } { t.Run(tc.name, func(t *testing.T) { config := &Config{ diff --git a/internal/app/subsystems/aio/sender/sender.go b/internal/app/subsystems/aio/sender/sender.go index e1fa28ac..57ce0456 100644 --- a/internal/app/subsystems/aio/sender/sender.go +++ b/internal/app/subsystems/aio/sender/sender.go @@ -248,17 +248,17 @@ func (w *SenderWorker) Process(sqe *bus.SQE[t_aio.Submission, t_aio.Completion]) var body []byte var err error - messgType := sqe.Submission.Sender.Task.Mesg.Type + mesgType := sqe.Submission.Sender.Task.Mesg.Type - if messgType == message.Notify { + if mesgType == message.Notify { util.Assert(sqe.Submission.Sender.Promise != nil, "promise must not be nil for a notify message") body, err = json.Marshal(map[string]interface{}{ - "type": messgType, + "type": mesgType, "promise": sqe.Submission.Sender.Promise, }) } else { body, err = json.Marshal(map[string]interface{}{ - "type": messgType, + "type": mesgType, "task": sqe.Submission.Sender.Task, "href": map[string]string{ "claim": sqe.Submission.Sender.ClaimHref, @@ -277,6 +277,7 @@ func (w *SenderWorker) Process(sqe *bus.SQE[t_aio.Submission, t_aio.Completion]) counter := w.metrics.AioInFlight.WithLabelValues(plugin.String()) ok := plugin.Enqueue(&aio.Message{ + Type: mesgType, Data: recv.Data, Body: body, Done: func(success bool, err error) {