From fdb54387efccd7f6e27f713a5d16f7bf31030252 Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Tue, 19 Aug 2025 12:47:26 -0400 Subject: [PATCH 1/7] checkpoint for comms mutations, validator linting clean --- api/comms/apply.go | 325 +++++++++++++++++ api/comms/chat_id.go | 21 ++ api/comms/constants.go | 21 ++ api/comms/get_new_blasts.go | 175 +++++++++ api/comms/rate_limit.go | 32 ++ api/comms/rate_limit_test.go | 202 +++++++++++ api/comms/raw_rpc.go | 16 + api/comms/rpc_log.go | 21 ++ api/comms/schema.go | 415 ++++++++++++++++++++++ api/comms/signed_request.go | 108 ++++++ api/comms/validator.go | 667 +++++++++++++++++++++++++++++++++++ api/comms_mutate.go | 67 ++++ api/server.go | 2 + 13 files changed, 2072 insertions(+) create mode 100644 api/comms/apply.go create mode 100644 api/comms/chat_id.go create mode 100644 api/comms/constants.go create mode 100644 api/comms/get_new_blasts.go create mode 100644 api/comms/rate_limit.go create mode 100644 api/comms/rate_limit_test.go create mode 100644 api/comms/raw_rpc.go create mode 100644 api/comms/rpc_log.go create mode 100644 api/comms/schema.go create mode 100644 api/comms/signed_request.go create mode 100644 api/comms/validator.go create mode 100644 api/comms_mutate.go diff --git a/api/comms/apply.go b/api/comms/apply.go new file mode 100644 index 00000000..cf1f2031 --- /dev/null +++ b/api/comms/apply.go @@ -0,0 +1,325 @@ +package comms + +import ( + "context" + "encoding/json" + "log/slog" + "strings" + "sync" + "time" + + "bridgerton.audius.co/api/dbv1" + "bridgerton.audius.co/trashid" + + // "comms.audius.co/discovery/config" + // "comms.audius.co/discovery/db" + // "comms.audius.co/discovery/db/queries" + // "comms.audius.co/discovery/misc" + // "comms.audius.co/discovery/schema" + // "github.com/jmoiron/sqlx" + + "github.com/tidwall/gjson" + // "gorm.io/gorm/logger" +) + +type RPCProcessor struct { + sync.Mutex + pool *dbv1.DBPools + validator *Validator + + // TODO + discoveryConfig *config.DiscoveryConfig +} + +func NewProcessor(pool *dbv1.DBPools, discoveryConfig *config.DiscoveryConfig) (*RPCProcessor, error) { + + // set up validator + limiter + limiter, err := NewRateLimiter() + if err != nil { + return nil, err + } + + aaoServer := "https://discoveryprovider.audius.co" + if discoveryConfig.IsStaging { + aaoServer = "https://discoveryprovider.staging.audius.co" + } + + if discoveryConfig.IsDev { + aaoServer = "http://audius-protocol-discovery-provider-1" + } + + validator := &Validator{ + pool: pool, + limiter: limiter, + aaoServer: aaoServer, + } + + proc := &RPCProcessor{ + validator: validator, + discoveryConfig: discoveryConfig, + } + + return proc, nil +} + +// TODO: replace logger +// Clean up wallet recovery (do we even need it?) +// Change format or at least naming of RpcLog +// Do we still need to check for already applied? +// - Maybe the validation needs to happen inside a transaction, since we check for existing stuff there? + +// Validates + applies a message +func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { + + logger := slog.With("sig", rpcLog.Sig) + var err error + + // check for already applied + var exists int + db.Conn.Get(&exists, `select count(*) from rpc_log where sig = $1`, rpcLog.Sig) + if exists == 1 { + logger.Debug("rpc already in log, skipping duplicate", "sig", rpcLog.Sig) + return nil + } + + startTime := time.Now() + takeSplit := func() time.Duration { + split := time.Since(startTime) + startTime = time.Now() + return split + } + + // validate signing wallet + wallet, err := misc.RecoverWallet(rpcLog.Rpc, rpcLog.Sig) + if err != nil { + logger.Warn("unable to recover wallet, skipping") + return nil + } + logger.Debug("recovered wallet", "took", takeSplit()) + + if wallet != rpcLog.FromWallet { + logger.Warn("recovered wallet no match", "recovered", wallet, "expected", rpcLog.FromWallet, "realeyd_by", rpcLog.RelayedBy) + return nil + } + + // parse raw rpc + var rawRpc RawRPC + err = json.Unmarshal(rpcLog.Rpc, &rawRpc) + if err != nil { + logger.Info(err.Error()) + return nil + } + + // check for "internal" message... + if strings.HasPrefix(rawRpc.Method, "internal.") { + err := proc.applyInternalMessage(rpcLog, &rawRpc) + if err != nil { + logger.Info("failed to apply internal rpc", "error", err) + } else { + logger.Info("applied internal RPC", "sig", rpcLog.Sig) + } + return nil + } + + // get ts + messageTs := rpcLog.RelayedAt + + userId, err := GetRPCCurrentUserID(rpcLog, &rawRpc) + if err != nil { + logger.Info("unable to get user ID") + return err // or nil? + } + + // for debugging + chatId := gjson.GetBytes(rpcLog.Rpc, "params.chat_id").String() + + logger = logger.With( + "wallet", wallet, + "userId", userId, + "relayed_by", rpcLog.RelayedBy, + "relayed_at", rpcLog.RelayedAt, + "chat_id", chatId, + "sig", rpcLog.Sig) + logger.Debug("got user", "took", takeSplit()) + + attemptApply := func() error { + + // write to db + tx, err := db.Conn.Beginx() + if err != nil { + return err + } + defer tx.Rollback() + + logger.Debug("begin tx", "took", takeSplit(), "sig", rpcLog.Sig) + + switch RPCMethod(rawRpc.Method) { + case RPCMethodChatCreate: + var params ChatCreateRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + err = chatCreate(tx, userId, messageTs, params) + if err != nil { + return err + } + case RPCMethodChatDelete: + var params ChatDeleteRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + err = chatDelete(tx, userId, params.ChatID, messageTs) + if err != nil { + return err + } + case RPCMethodChatMessage: + var params ChatMessageRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + err = chatSendMessage(tx, userId, params.ChatID, params.MessageID, messageTs, params.Message) + if err != nil { + return err + } + case RPCMethodChatReact: + var params ChatReactRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + err = chatReactMessage(tx, userId, params.ChatID, params.MessageID, params.Reaction, messageTs) + if err != nil { + return err + } + + case RPCMethodChatRead: + var params ChatReadRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + // do nothing if last active at >= message timestamp + lastActive, err := queries.LastActiveAt(tx, context.Background(), queries.ChatMembershipParams{ + ChatID: params.ChatID, + UserID: userId, + }) + if err != nil { + return err + } + if !lastActive.Valid || messageTs.After(lastActive.Time) { + err = chatReadMessages(tx, userId, params.ChatID, messageTs) + if err != nil { + return err + } + } + case RPCMethodChatPermit: + var params ChatPermitRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + err = chatSetPermissions(tx, userId, params.Permit, params.PermitList, params.Allow, messageTs) + if err != nil { + return err + } + case RPCMethodChatBlock: + var params ChatBlockRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + blockeeUserId, err := trashid.DecodeHashId(params.UserID) + if err != nil { + return err + } + err = chatBlock(tx, userId, int32(blockeeUserId), messageTs) + if err != nil { + return err + } + case RPCMethodChatUnblock: + var params ChatUnblockRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + unblockedUserId, err := trashid.DecodeHashId(params.UserID) + if err != nil { + return err + } + err = chatUnblock(tx, userId, int32(unblockedUserId), messageTs) + if err != nil { + return err + } + + case RPCMethodChatBlast: + var params ChatBlastRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + + outgoingMessages, err := chatBlast(tx, userId, messageTs, params) + if err != nil { + return err + } + // Send chat message websocket event to all recipients who have existing chats + for _, outgoingMessage := range outgoingMessages { + j, err := json.Marshal(outgoingMessage.ChatMessageRPC) + if err != nil { + slog.Error("err: invalid json", "err", err) + } else { + // TODO + // websocketNotify(json.RawMessage(j), userId, messageTs.Round(time.Microsecond)) + } + } + default: + logger.Warn("no handler for ", rawRpc.Method) + } + + logger.Debug("called handler", "took", takeSplit()) + + err = tx.Commit() + if err != nil { + return err + } + logger.Debug("commited", "took", takeSplit()) + + // TODO + // send out websocket events fire + forget style + // websocketNotify(rpcLog.Rpc, userId, messageTs.Round(time.Microsecond)) + logger.Debug("websocket push done", "took", takeSplit()) + + return nil + } + + err = attemptApply() + if err != nil { + logger.Warn("apply failed", "err", err) + } + return err +} + +// func websocketNotify(rpcJson json.RawMessage, userId int32, timestamp time.Time) { +// if chatId := gjson.GetBytes(rpcJson, "params.chat_id").String(); chatId != "" { + +// var userIds []int32 +// err := db.Conn.Select(&userIds, `select user_id from chat_member where chat_id = $1 and is_hidden = false`, chatId) +// if err != nil { +// logger.Warn("failed to load chat members for websocket push " + err.Error()) +// return +// } + +// for _, receiverUserId := range userIds { +// websocketPush(userId, receiverUserId, rpcJson, timestamp) +// } +// } else if gjson.GetBytes(rpcJson, "method").String() == "chat.blast" { +// go func() { +// // Add delay before broadcasting blast messages - see PAY-3573 +// time.Sleep(30 * time.Second) +// websocketPushAll(userId, rpcJson, timestamp) +// }() +// } +// } diff --git a/api/comms/chat_id.go b/api/comms/chat_id.go new file mode 100644 index 00000000..03605404 --- /dev/null +++ b/api/comms/chat_id.go @@ -0,0 +1,21 @@ +package comms + +import ( + "fmt" + + "bridgerton.audius.co/trashid" +) + +// ChatID return a encodedUser1:encodedUser2 ID where encodedUser1 is < encodedUser2 +// which is the convention used to make chat IDs deterministic. +// See makeChatId in SDK: packages/common/src/store/pages/chat/utils.ts +func ChatID(id1, id2 int) string { + // TODO: Handle errors + user1IdEncoded, _ := trashid.EncodeHashId(id1) + user2IdEncoded, _ := trashid.EncodeHashId(id2) + chatId := fmt.Sprintf("%s:%s", user1IdEncoded, user2IdEncoded) + if user2IdEncoded < user1IdEncoded { + chatId = fmt.Sprintf("%s:%s", user2IdEncoded, user1IdEncoded) + } + return chatId +} diff --git a/api/comms/constants.go b/api/comms/constants.go new file mode 100644 index 00000000..810f0ebb --- /dev/null +++ b/api/comms/constants.go @@ -0,0 +1,21 @@ +package comms + +var ( + SigHeader = "x-sig" + SignatureTimeToLiveMs = int64(1000 * 60 * 60 * 12) // 12 hours + + // TODO: Do we need these configurable? + // Rate limit config + RateLimitRulesBucketName = "rateLimitRules" + RateLimitTimeframeHours = "timeframeHours" + RateLimitMaxNumMessages = "maxNumMessages" + RateLimitMaxNumMessagesPerRecipient = "maxNumMessagesPerRecipient" + RateLimitMaxNumNewChats = "maxNumNewChats" + + DefaultRateLimitRules = map[string]int{ + RateLimitTimeframeHours: 24, + RateLimitMaxNumMessages: 2000, + RateLimitMaxNumMessagesPerRecipient: 1000, + RateLimitMaxNumNewChats: 100000, + } +) diff --git a/api/comms/get_new_blasts.go b/api/comms/get_new_blasts.go new file mode 100644 index 00000000..c1e9644c --- /dev/null +++ b/api/comms/get_new_blasts.go @@ -0,0 +1,175 @@ +package comms + +import ( + "context" + "database/sql" + "time" + + "bridgerton.audius.co/api/dbv1" + "bridgerton.audius.co/trashid" + "github.com/jackc/pgx/v5" +) + +type BlastRow struct { + PendingChatID string `db:"-" json:"pending_chat_id"` + BlastID string `db:"blast_id" json:"blast_id"` + FromUserID int32 `db:"from_user_id" json:"from_user_id"` + Audience string `db:"audience" json:"audience"` + AudienceContentType sql.NullString `db:"audience_content_type" json:"audience_content_type"` + AudienceContentID sql.NullInt32 `db:"audience_content_id" json:"-"` + AudienceContentIDEncoded sql.NullString `db:"-" json:"audience_content_id"` + Plaintext string `db:"plaintext" json:"plaintext"` + CreatedAt time.Time `db:"created_at" json:"created_at"` +} + +type GetNewBlastsParams struct { + UserID int32 `db:"user_id" json:"user_id"` + ChatID string `db:"chat_id" json:"chat_id"` +} + +// Helper function to get new blasts, used by the mutation Validator +func GetNewBlasts(pool *dbv1.DBPools, ctx context.Context, arg GetNewBlastsParams) ([]BlastRow, error) { + + // this query is to find new blasts for the current user + // which don't already have a existing chat. + // see also: subtly different inverse query exists in chat_blast.go + // to fan out messages to existing chat + var findNewBlasts = ` + with + last_permission_change as ( + select max(t) as t from ( + select updated_at as t from chat_permissions where user_id = $1 + union + select created_at as t from chat_blocked_users where blocker_user_id = $1 + union + select to_timestamp(0) + ) as timestamp_subquery + ), + all_new as ( + select * + from chat_blast blast + where + from_user_id in ( + -- follower_audience + SELECT followee_user_id AS from_user_id + FROM follows + WHERE blast.audience = 'follower_audience' + AND follows.followee_user_id = blast.from_user_id + AND follows.follower_user_id = $1 + AND follows.is_delete = false + AND follows.created_at < blast.created_at + ) + OR from_user_id in ( + -- tipper_audience + SELECT receiver_user_id + FROM user_tips tip + WHERE blast.audience = 'tipper_audience' + AND receiver_user_id = blast.from_user_id + AND sender_user_id = $1 + AND tip.created_at < blast.created_at + ) + OR from_user_id IN ( + -- remixer_audience + SELECT og.owner_id + FROM tracks t + JOIN remixes ON remixes.child_track_id = t.track_id + JOIN tracks og ON remixes.parent_track_id = og.track_id + WHERE blast.audience = 'remixer_audience' + AND og.owner_id = blast.from_user_id + AND t.owner_id = $1 + AND ( + blast.audience_content_id IS NULL + OR ( + blast.audience_content_type = 'track' + AND blast.audience_content_id = og.track_id + ) + ) + ) + OR from_user_id IN ( + -- customer_audience + SELECT seller_user_id + FROM usdc_purchases p + WHERE blast.audience = 'customer_audience' + AND p.seller_user_id = blast.from_user_id + AND p.buyer_user_id = $1 + AND ( + audience_content_id IS NULL + OR ( + blast.audience_content_type = p.content_type::text + AND blast.audience_content_id = p.content_id + ) + ) + ) + OR from_user_id IN ( + -- coin_holder_audience via sol_user_balances + SELECT ac.user_id + FROM artist_coins ac + JOIN sol_user_balances sub ON sub.mint = ac.mint + WHERE blast.audience = 'coin_holder_audience' + AND ac.user_id = blast.from_user_id + AND sub.user_id = $1 + AND sub.balance > 0 + -- TODO: PE-6663 This isn't entirely correct yet, need to check "time of most recent membership" + AND sub.created_at < blast.created_at + ) + ) + select * from all_new + where created_at > (select t from last_permission_change) + and chat_allowed(from_user_id, $1) + order by created_at + ` + + rows, err := pool.Query(ctx, findNewBlasts, arg.UserID) + if err != nil { + return nil, err + } + + items, err := pgx.CollectRows(rows, pgx.RowToStructByName[BlastRow]) + if err != nil { + return nil, err + } + + for idx, blastRow := range items { + chatId := ChatID(int(arg.UserID), int(blastRow.FromUserID)) + items[idx].PendingChatID = chatId + + if blastRow.AudienceContentID.Valid { + encoded, _ := trashid.EncodeHashId(int(blastRow.AudienceContentID.Int32)) + items[idx].AudienceContentIDEncoded = sql.NullString{ + String: encoded, + Valid: true, + } + } + } + + rows, err = pool.Query(ctx, `select chat_id from chat_member where user_id = $1`, arg.UserID) + if err != nil { + return nil, err + } + + existingChatIdList, err := pgx.CollectRows(rows, pgx.RowTo[string]) + if err != nil { + return nil, err + } + + existingChatIds := map[string]bool{} + for _, id := range existingChatIdList { + existingChatIds[id] = true + } + + // filter out blast rows where chatIds is taken + filtered := make([]BlastRow, 0, len(items)) + for _, item := range items { + if existingChatIds[item.PendingChatID] { + continue + } + // allow caller to filter to blasts for a given chat ID + if arg.ChatID != "" && item.PendingChatID != arg.ChatID { + continue + } + filtered = append(filtered, item) + } + + return filtered, err + +} diff --git a/api/comms/rate_limit.go b/api/comms/rate_limit.go new file mode 100644 index 00000000..95796b3a --- /dev/null +++ b/api/comms/rate_limit.go @@ -0,0 +1,32 @@ +package comms + +import ( + "sync" +) + +func NewRateLimiter() (*RateLimiter, error) { + + limiter := &RateLimiter{ + limits: map[string]int{}, + } + + return limiter, nil +} + +type RateLimiter struct { + sync.RWMutex + limits map[string]int +} + +func (limiter *RateLimiter) Get(rule string) int { + limiter.RLock() + defer limiter.RUnlock() + + if val := limiter.limits[rule]; val != 0 { + return val + } + + // TODO + // return config.DefaultRateLimitRules[rule] + return 0 +} diff --git a/api/comms/rate_limit_test.go b/api/comms/rate_limit_test.go new file mode 100644 index 00000000..fd0d3220 --- /dev/null +++ b/api/comms/rate_limit_test.go @@ -0,0 +1,202 @@ +package comms + +import ( + "fmt" + "math/rand" + "strconv" + "testing" + "time" + + "comms.audius.co/discovery/db" + "comms.audius.co/discovery/misc" + "comms.audius.co/discovery/schema" + "github.com/stretchr/testify/assert" +) + +func TestRateLimit(t *testing.T) { + + // todo: update for no-nats + t.Skip() + + var err error + + // reset tables under test + _, err = db.Conn.Exec("truncate table chat cascade") + assert.NoError(t, err) + + // Add test rules + // testRules := map[string]int{ + // config.RateLimitTimeframeHours: 24, + // config.RateLimitMaxNumMessages: 3, + // config.RateLimitMaxNumMessagesPerRecipient: 2, + // config.RateLimitMaxNumNewChats: 2, + // } + // for rule, limit := range testRules { + // _, err := kv.PutString(rule, strconv.Itoa(limit)) + // assert.NoError(t, err) + // } + + tx := db.Conn.MustBegin() + + seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) + user1Id := seededRand.Int31() + user2Id := seededRand.Int31() + user3Id := seededRand.Int31() + user4Id := seededRand.Int31() + user5Id := seededRand.Int31() + + user1IdEncoded, err := misc.EncodeHashId(int(user1Id)) + assert.NoError(t, err) + user3IdEncoded, err := misc.EncodeHashId(int(user3Id)) + assert.NoError(t, err) + user4IdEncoded, err := misc.EncodeHashId(int(user4Id)) + assert.NoError(t, err) + user5IdEncoded, err := misc.EncodeHashId(int(user5Id)) + assert.NoError(t, err) + + // user1Id created a new chat with user2Id 48 hours ago + chatId1 := strconv.Itoa(seededRand.Int()) + chatTs := time.Now().UTC().Add(-time.Hour * time.Duration(48)) + _, err = tx.Exec("insert into chat (chat_id, created_at, last_message_at) values ($1, $2, $2)", chatId1, chatTs) + assert.NoError(t, err) + _, err = tx.Exec("insert into chat_member (chat_id, invited_by_user_id, invite_code, user_id, created_at) values ($1, $2, $1, $2, $4), ($1, $2, $1, $3, $4)", chatId1, user1Id, user2Id, chatTs) + assert.NoError(t, err) + + // user1Id messaged user2Id 48 hours ago + err = chatSendMessage(tx, user1Id, chatId1, "1", chatTs, "Hello") + assert.NoError(t, err) + + // user1Id messages user2Id twice now + message := "Hello today 1" + messageRpc := schema.RawRPC{ + Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId1, message)), + } + err = testValidator.validateChatMessage(tx, user1Id, messageRpc) + assert.NoError(t, err) + err = chatSendMessage(tx, user1Id, chatId1, "2", time.Now().UTC(), message) + assert.NoError(t, err) + message = "Hello today 2" + messageRpc = schema.RawRPC{ + Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId1, message)), + } + err = testValidator.validateChatMessage(tx, user1Id, messageRpc) + assert.NoError(t, err) + err = chatSendMessage(tx, user1Id, chatId1, "3", time.Now().UTC(), message) + assert.NoError(t, err) + + // user1Id messages user2Id a 3rd time + // Blocked by rate limiter (hit max # messages per recipient in the past 24 hours) + message = "Hello again again." + messageRpc = schema.RawRPC{ + Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId1, message)), + } + err = testValidator.validateChatMessage(tx, user1Id, messageRpc) + assert.ErrorContains(t, err, "User has exceeded the maximum number of new messages") + + // user1Id creates a new chat with user3Id (1 chat created in 24h) + chatId2 := strconv.Itoa(seededRand.Int()) + createRpc := schema.RawRPC{ + Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "invites": [{"user_id": "%s", "invite_code": "%s"}, {"user_id": "%s", "invite_code": "%s"}]}`, chatId2, user1IdEncoded, chatId2, user3IdEncoded, chatId2)), + } + err = testValidator.validateChatCreate(tx, user1Id, createRpc) + assert.NoError(t, err) + SetupChatWithMembers(t, tx, chatId2, user1Id, user3Id) + + // user1Id messages user3Id + // Still blocked by rate limiter (hit max # messages with user2Id in the past 24h) + message = "Hi user3Id" + messageRpc = schema.RawRPC{ + Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId2, message)), + } + err = testValidator.validateChatMessage(tx, user1Id, messageRpc) + assert.ErrorContains(t, err, "User has exceeded the maximum number of new messages") + + // Remove message 3 from db so can test other rate limits + _, err = tx.Exec("delete from chat_message where message_id = '3'") + assert.NoError(t, err) + + // user1Id should be able to message user3Id now + err = testValidator.validateChatMessage(tx, user1Id, messageRpc) + assert.NoError(t, err) + err = chatSendMessage(tx, user1Id, chatId2, "3", time.Now().UTC(), message) + assert.NoError(t, err) + + // user1Id creates a new chat with user4Id (2 chats created in 24h) + chatId3 := strconv.Itoa(seededRand.Int()) + createRpc = schema.RawRPC{ + Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "invites": [{"user_id": "%s", "invite_code": "%s"}, {"user_id": "%s", "invite_code": "%s"}]}`, chatId3, user1IdEncoded, chatId3, user4IdEncoded, chatId3)), + } + err = testValidator.validateChatCreate(tx, user1Id, createRpc) + assert.NoError(t, err) + SetupChatWithMembers(t, tx, chatId3, user1Id, user4Id) + + // user1Id messages user4Id + message = "Hi user4Id again" + messageRpc = schema.RawRPC{ + Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId3, message)), + } + err = testValidator.validateChatMessage(tx, user1Id, messageRpc) + assert.NoError(t, err) + err = chatSendMessage(tx, user1Id, chatId3, "4", time.Now().UTC(), message) + assert.NoError(t, err) + + // user1Id messages user4Id again (4th message to anyone in 24h) + // Blocked by rate limiter (hit max # messages in the past 24 hours) + message = "Hi user4Id again" + messageRpc = schema.RawRPC{ + Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId3, message)), + } + err = testValidator.validateChatMessage(tx, user1Id, messageRpc) + assert.ErrorContains(t, err, "User has exceeded the maximum number of new messages") + + // user1Id creates a new chat with user5Id (3 chats created in 24h) + // Blocked by rate limiter (hit max # new chats) + chatId4 := strconv.Itoa(seededRand.Int()) + createRpc = schema.RawRPC{ + Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "invites": [{"user_id": "%s", "invite_code": "%s"}, {"user_id": "%s", "invite_code": "%s"}]}`, chatId4, user1IdEncoded, chatId2, user5IdEncoded, chatId4)), + } + err = testValidator.validateChatCreate(tx, user1Id, createRpc) + assert.ErrorContains(t, err, "An invited user has exceeded the maximum number of new chats") + + tx.Rollback() +} + +func TestBurstRateLimit(t *testing.T) { + var err error + + // reset tables under test + _, err = db.Conn.Exec("truncate table chat cascade") + assert.NoError(t, err) + + tx := db.Conn.MustBegin() + defer tx.Rollback() + + seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) + chatId := strconv.Itoa(seededRand.Int()) + user1Id := seededRand.Int31() + user2Id := seededRand.Int31() + + SetupChatWithMembers(t, tx, chatId, user1Id, user2Id) + + // hit the 1 second limit... send a burst of messages + for i := 1; i < 16; i++ { + + message := fmt.Sprintf("burst %d", i) + err = chatSendMessage(tx, user1Id, chatId, message, time.Now().UTC(), message) + assert.NoError(t, err, "i is", i) + + messageRpc := schema.RawRPC{ + Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId, message)), + } + err = testValidator.validateChatMessage(tx, user1Id, messageRpc) + + // first 10 messages are ok... + // and then the per-second rate limiter kicks in + if i <= 10 { + assert.NoError(t, err, "i is", i) + } else { + assert.ErrorIs(t, err, ErrMessageRateLimitExceeded, "i = ", i) + } + } + +} diff --git a/api/comms/raw_rpc.go b/api/comms/raw_rpc.go new file mode 100644 index 00000000..7cd4d5c1 --- /dev/null +++ b/api/comms/raw_rpc.go @@ -0,0 +1,16 @@ +package comms + +import ( + "encoding/json" +) + +// RawRPC matches (quicktype generated) RPC +// Except Params is a json.RawMessage instead of a quicktype approximation of a golang union type which sadly doesn't really exist. +// which is more generic + convienent to use in go code +// it should match the fields of RPC +type RawRPC struct { + CurrentUserID string `json:"current_user_id"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` + Timestamp int64 `json:"timestamp"` +} diff --git a/api/comms/rpc_log.go b/api/comms/rpc_log.go new file mode 100644 index 00000000..710c4a0d --- /dev/null +++ b/api/comms/rpc_log.go @@ -0,0 +1,21 @@ +package comms + +import ( + "encoding/json" + "time" +) + +// RpcLog is passed around between servers +// It has the rpc and sig (from header) +// The relay server that receives the RPC will stamp it with relayed_by and relayed_at +// relayed_by and relayed_at are used so that peer servers can consume a http-feeds style event feed +// from every peer +type RpcLog struct { + ID string `db:"id" json:"id"` + RelayedAt time.Time `db:"relayed_at" json:"relayed_at"` + AppliedAt time.Time `db:"applied_at" json:"applied_at"` + RelayedBy string `db:"relayed_by" json:"relayed_by"` + FromWallet string `db:"from_wallet" json:"from_wallet"` + Rpc json.RawMessage `db:"rpc" json:"rpc"` + Sig string `db:"sig" json:"sig"` +} diff --git a/api/comms/schema.go b/api/comms/schema.go new file mode 100644 index 00000000..25a4134a --- /dev/null +++ b/api/comms/schema.go @@ -0,0 +1,415 @@ +package comms + +type ValidateCanChatRPC struct { + Method ValidateCanChatRPCMethod `json:"method"` + Params ValidateCanChatRPCParams `json:"params"` +} + +type ValidateCanChatRPCParams struct { + ReceiverUserIDS []string `json:"receiver_user_ids"` +} + +type ChatBlastRPC struct { + Method ChatBlastRPCMethod `json:"method"` + Params ChatBlastRPCParams `json:"params"` +} + +type ChatBlastRPCParams struct { + Audience ChatBlastAudience `json:"audience"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + BlastID string `json:"blast_id"` + Message string `json:"message"` +} + +type ChatCreateRPC struct { + Method ChatCreateRPCMethod `json:"method"` + Params ChatCreateRPCParams `json:"params"` +} + +type ChatCreateRPCParams struct { + ChatID string `json:"chat_id"` + Invites []PurpleInvite `json:"invites"` +} + +type PurpleInvite struct { + InviteCode string `json:"invite_code"` + UserID string `json:"user_id"` +} + +type ChatDeleteRPC struct { + Method ChatDeleteRPCMethod `json:"method"` + Params ChatDeleteRPCParams `json:"params"` +} + +type ChatDeleteRPCParams struct { + ChatID string `json:"chat_id"` +} + +type ChatInviteRPC struct { + Method ChatInviteRPCMethod `json:"method"` + Params ChatInviteRPCParams `json:"params"` +} + +type ChatInviteRPCParams struct { + ChatID string `json:"chat_id"` + Invites []FluffyInvite `json:"invites"` +} + +type FluffyInvite struct { + InviteCode string `json:"invite_code"` + UserID string `json:"user_id"` +} + +type ChatMessageRPC struct { + Method ChatMessageRPCMethod `json:"method"` + Params ChatMessageRPCParams `json:"params"` +} + +type ChatMessageRPCParams struct { + ChatID string `json:"chat_id"` + IsPlaintext *bool `json:"is_plaintext,omitempty"` + Message string `json:"message"` + MessageID string `json:"message_id"` + ParentMessageID *string `json:"parent_message_id,omitempty"` + Audience *ChatBlastAudience `json:"audience,omitempty"` +} + +type ChatReactRPC struct { + Method ChatReactRPCMethod `json:"method"` + Params ChatReactRPCParams `json:"params"` +} + +type ChatReactRPCParams struct { + ChatID string `json:"chat_id"` + MessageID string `json:"message_id"` + Reaction *string `json:"reaction"` +} + +type ChatReadRPC struct { + Method ChatReadRPCMethod `json:"method"` + Params ChatReadRPCParams `json:"params"` +} + +type ChatReadRPCParams struct { + ChatID string `json:"chat_id"` +} + +type ChatBlockRPC struct { + Method ChatBlockRPCMethod `json:"method"` + Params ChatBlockRPCParams `json:"params"` +} + +type ChatBlockRPCParams struct { + UserID string `json:"user_id"` +} + +type ChatUnblockRPC struct { + Method ChatUnblockRPCMethod `json:"method"` + Params ChatUnblockRPCParams `json:"params"` +} + +type ChatUnblockRPCParams struct { + UserID string `json:"user_id"` +} + +type ChatPermitRPC struct { + Method ChatPermitRPCMethod `json:"method"` + Params ChatPermitRPCParams `json:"params"` +} + +type ChatPermitRPCParams struct { + Allow *bool `json:"allow,omitempty"` + Permit ChatPermission `json:"permit"` + PermitList []ChatPermission `json:"permit_list"` +} + +type RPCPayloadRequest struct { + Method RPCMethod `json:"method"` + Params RPCPayloadRequestParams `json:"params"` +} + +type RPCPayloadRequestParams struct { + ReceiverUserIDS []string `json:"receiver_user_ids,omitempty"` + Audience *ChatBlastAudience `json:"audience,omitempty"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + BlastID *string `json:"blast_id,omitempty"` + Message *string `json:"message,omitempty"` + ChatID *string `json:"chat_id,omitempty"` + Invites []TentacledInvite `json:"invites,omitempty"` + IsPlaintext *bool `json:"is_plaintext,omitempty"` + MessageID *string `json:"message_id,omitempty"` + ParentMessageID *string `json:"parent_message_id,omitempty"` + Reaction *string `json:"reaction"` + UserID *string `json:"user_id,omitempty"` + Allow *bool `json:"allow,omitempty"` + Permit *ChatPermission `json:"permit,omitempty"` + PermitList []ChatPermission `json:"permit_list,omitempty"` +} + +type TentacledInvite struct { + InviteCode string `json:"invite_code"` + UserID string `json:"user_id"` +} + +type UserChat struct { + Audience ChatBlastAudience `json:"audience"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *string `json:"audience_content_type,omitempty"` + ChatID string `json:"chat_id"` + ChatMembers []ChatMember `json:"chat_members"` + ClearedHistoryAt string `json:"cleared_history_at"` + InviteCode string `json:"invite_code"` + IsBlast bool `json:"is_blast"` + LastMessage string `json:"last_message"` + LastMessageAt string `json:"last_message_at"` + LastMessageIsPlaintext bool `json:"last_message_is_plaintext"` + LastReadAt string `json:"last_read_at"` + RecheckPermissions bool `json:"recheck_permissions"` + UnreadMessageCount float64 `json:"unread_message_count"` +} + +type ChatMember struct { + UserID string `json:"user_id"` +} + +type ChatMessageReaction struct { + CreatedAt string `json:"created_at"` + Reaction string `json:"reaction"` + UserID string `json:"user_id"` +} + +type ChatMessageNullableReaction struct { + CreatedAt string `json:"created_at"` + Reaction *string `json:"reaction"` + UserID string `json:"user_id"` +} + +type ChatMessage struct { + CreatedAt string `json:"created_at"` + IsPlaintext bool `json:"is_plaintext"` + Message string `json:"message"` + MessageID string `json:"message_id"` + Reactions []Reaction `json:"reactions"` + SenderUserID string `json:"sender_user_id"` +} + +type Reaction struct { + CreatedAt string `json:"created_at"` + Reaction string `json:"reaction"` + UserID string `json:"user_id"` +} + +type ChatInvite struct { + InviteCode string `json:"invite_code"` + UserID string `json:"user_id"` +} + +type ChatBlastBase struct { + Audience ChatBlastAudience `json:"audience"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + ChatID string `json:"chat_id"` +} + +type UpgradableChatBlast struct { + Audience ChatBlastAudience `json:"audience"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + BlastID string `json:"blast_id"` + ChatID string `json:"chat_id"` + CreatedAt string `json:"created_at"` + FromUserID float64 `json:"from_user_id"` + PendingChatID string `json:"pending_chat_id"` + Plaintext string `json:"plaintext"` +} + +type ChatBlast struct { + Audience ChatBlastAudience `json:"audience"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + ChatID string `json:"chat_id"` + IsBlast bool `json:"is_blast"` + LastMessageAt string `json:"last_message_at"` +} + +type ValidatedChatPermissions struct { + CurrentUserHasPermission bool `json:"current_user_has_permission"` + PermitList []ChatPermission `json:"permit_list"` + Permits ChatPermission `json:"permits"` + UserID string `json:"user_id"` +} + +type CommsResponse struct { + Data interface{} `json:"data"` + Health Health `json:"health"` + Summary *Summary `json:"summary,omitempty"` +} + +type Health struct { + IsHealthy bool `json:"is_healthy"` +} + +type Summary struct { + NextCount float64 `json:"next_count"` + NextCursor string `json:"next_cursor"` + PrevCount float64 `json:"prev_count"` + PrevCursor string `json:"prev_cursor"` + TotalCount float64 `json:"total_count"` +} + +type ChatWebsocketEventData struct { + Metadata Metadata `json:"metadata"` + RPC RPCPayload `json:"rpc"` +} + +type Metadata struct { + ReceiverUserID string `json:"receiverUserId"` + SenderUserID string `json:"senderUserId"` + Timestamp string `json:"timestamp"` + UserID string `json:"userId"` +} + +type RPCPayload struct { + CurrentUserID string `json:"current_user_id"` + Method RPCMethod `json:"method"` + Params RPCPayloadParams `json:"params"` + Timestamp float64 `json:"timestamp"` +} + +type RPCPayloadParams struct { + ReceiverUserIDS []string `json:"receiver_user_ids,omitempty"` + Audience *ChatBlastAudience `json:"audience,omitempty"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + BlastID *string `json:"blast_id,omitempty"` + Message *string `json:"message,omitempty"` + ChatID *string `json:"chat_id,omitempty"` + Invites []StickyInvite `json:"invites,omitempty"` + IsPlaintext *bool `json:"is_plaintext,omitempty"` + MessageID *string `json:"message_id,omitempty"` + ParentMessageID *string `json:"parent_message_id,omitempty"` + Reaction *string `json:"reaction"` + UserID *string `json:"user_id,omitempty"` + Allow *bool `json:"allow,omitempty"` + Permit *ChatPermission `json:"permit,omitempty"` + PermitList []ChatPermission `json:"permit_list,omitempty"` +} + +type StickyInvite struct { + InviteCode string `json:"invite_code"` + UserID string `json:"user_id"` +} + +type ValidateCanChatRPCMethod string + +const ( + MethodUserValidateCanChat ValidateCanChatRPCMethod = "user.validate_can_chat" +) + +type ChatBlastRPCMethod string + +const ( + MethodChatBlast ChatBlastRPCMethod = "chat.blast" +) + +type ChatBlastAudience string + +const ( + CustomerAudience ChatBlastAudience = "customer_audience" + FollowerAudience ChatBlastAudience = "follower_audience" + RemixerAudience ChatBlastAudience = "remixer_audience" + TipperAudience ChatBlastAudience = "tipper_audience" + CoinHolderAudience ChatBlastAudience = "coin_holder_audience" +) + +type AudienceContentType string + +const ( + Album AudienceContentType = "album" + Track AudienceContentType = "track" +) + +type ChatCreateRPCMethod string + +const ( + MethodChatCreate ChatCreateRPCMethod = "chat.create" +) + +type ChatDeleteRPCMethod string + +const ( + MethodChatDelete ChatDeleteRPCMethod = "chat.delete" +) + +type ChatInviteRPCMethod string + +const ( + MethodChatInvite ChatInviteRPCMethod = "chat.invite" +) + +type ChatMessageRPCMethod string + +const ( + MethodChatMessage ChatMessageRPCMethod = "chat.message" +) + +type ChatReactRPCMethod string + +const ( + MethodChatReact ChatReactRPCMethod = "chat.react" +) + +type ChatReadRPCMethod string + +const ( + MethodChatRead ChatReadRPCMethod = "chat.read" +) + +type ChatBlockRPCMethod string + +const ( + MethodChatBlock ChatBlockRPCMethod = "chat.block" +) + +type ChatUnblockRPCMethod string + +const ( + MethodChatUnblock ChatUnblockRPCMethod = "chat.unblock" +) + +type ChatPermitRPCMethod string + +const ( + MethodChatPermit ChatPermitRPCMethod = "chat.permit" +) + +// Defines who the user allows to message them +type ChatPermission string + +const ( + All ChatPermission = "all" + Followees ChatPermission = "followees" + Followers ChatPermission = "followers" + None ChatPermission = "none" + Tippees ChatPermission = "tippees" + Tippers ChatPermission = "tippers" + Verified ChatPermission = "verified" +) + +type RPCMethod string + +const ( + RPCMethodChatBlast RPCMethod = "chat.blast" + RPCMethodChatBlock RPCMethod = "chat.block" + RPCMethodChatCreate RPCMethod = "chat.create" + RPCMethodChatDelete RPCMethod = "chat.delete" + RPCMethodChatInvite RPCMethod = "chat.invite" + RPCMethodChatMessage RPCMethod = "chat.message" + RPCMethodChatPermit RPCMethod = "chat.permit" + RPCMethodChatReact RPCMethod = "chat.react" + RPCMethodChatRead RPCMethod = "chat.read" + RPCMethodChatUnblock RPCMethod = "chat.unblock" + RPCMethodUserValidateCanChat RPCMethod = "user.validate_can_chat" +) diff --git a/api/comms/signed_request.go b/api/comms/signed_request.go new file mode 100644 index 00000000..ade07d09 --- /dev/null +++ b/api/comms/signed_request.go @@ -0,0 +1,108 @@ +package comms + +import ( + "encoding/base64" + "errors" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/gofiber/fiber/v2" +) + +// func userIdForSignedGet(c echo.Context) (int32, error) { +// if c.Request().Method != "GET" { +// return 0, errors.New("readSignedGet: bad method: " + c.Request().Method) +// } + +// sigBase64 := c.Request().Header.Get(signing.SigHeader) + +// // for websocket request, read from query param instead of header +// if querySig := c.QueryParam("signature"); sigBase64 == "" && querySig != "" { +// sigBase64 = querySig +// } + +// // helper function to log error if present +// logError := func(err error) (int32, error) { +// if err != nil { +// slog.Warn("ReadSignedRequest error", +// "err", err, +// "url", c.Request().URL.String(), +// "sig", sigBase64) +// } +// return 0, err +// } + +// // Check that timestamp is not too old +// timestamp, err := strconv.ParseInt(c.QueryParam("timestamp"), 0, 64) +// if err != nil { +// return logError(err) +// } + +// tsAge := time.Now().UnixMilli() - timestamp +// if tsAge < 0 { +// tsAge *= -1 +// } +// if tsAge > signing.SignatureTimeToLiveMs { +// return logError(errors.New("timestamp not current")) +// } + +// // Strip out the app_name and api_key query parameters to get the true signature payload +// u := *c.Request().URL +// q := u.Query() +// q.Del("app_name") +// q.Del("api_key") +// q.Del("signature") +// u.RawQuery = q.Encode() +// payload := []byte(u.String()) + +// wallet, err := recoverSigningWallet(sigBase64, payload) +// if err != nil { +// return logError(errors.New("failed to recoverSigningWallet")) +// } + +// userId, err := queries.GetUserIDFromWallet(db.Conn, c.Request().Context(), wallet, c.QueryParam("current_user_id")) +// if err != nil { +// return logError(fmt.Errorf("failed to get user_id for wallet: %s", wallet)) +// } + +// return userId, nil +// } + +func ReadSignedPost(c *fiber.Ctx) ([]byte, string, error) { + if c.Method() != "POST" { + return nil, "", errors.New("readSignedPost bad method: " + c.Method()) + } + + payload := c.Body() + + sigHex := c.Get(SigHeader) + wallet, err := recoverSigningWallet(sigHex, payload) + return payload, wallet, err +} + +func recoverSigningWallet(signatureHex string, signedData []byte) (string, error) { + sig, err := base64.StdEncoding.DecodeString(signatureHex) + if err != nil { + err = errors.New("bad sig header: " + err.Error()) + return "", err + } + + // recover + hash := crypto.Keccak256Hash(signedData) + pubkey, err := crypto.SigToPub(hash[:], sig) + if err != nil { + return "", err + } + + wallet := crypto.PubkeyToAddress(*pubkey).Hex() + + // TODO: Still need this? We have a function for getting these in another file + // seed the user pubkey if missing + // err = pubkeystore.SetPubkeyForWallet(wallet, pubkey) + // if err != nil { + // slog.Warn("failed to SetPubkeyForWallet", "wallet", wallet, "err", err) + // } else { + // slog.Info("SetPubkeyForWallet OK", "wallet", wallet) + // } + + return wallet, nil +} diff --git a/api/comms/validator.go b/api/comms/validator.go new file mode 100644 index 00000000..48a9357c --- /dev/null +++ b/api/comms/validator.go @@ -0,0 +1,667 @@ +package comms + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "time" + + "bridgerton.audius.co/api/dbv1" + "bridgerton.audius.co/trashid" + "go.uber.org/zap" +) + +var ( + ErrMessageRateLimitExceeded = errors.New("user has exceeded the maximum number of new messages") +) + +type Validator struct { + logger *zap.Logger + pool *dbv1.DBPools + limiter *RateLimiter + aaoServer string +} + +func NewValidator(pool *dbv1.DBPools, limiter *RateLimiter, aaoServer string) *Validator { + return &Validator{ + pool: pool, + limiter: limiter, + aaoServer: aaoServer, + } +} + +func (vtor *Validator) Validate(ctx context.Context, userId int32, rawRpc RawRPC) error { + methodName := RPCMethod(rawRpc.Method) + + // actually skip timestamp check for now... + // POST endpoint will check for recency... + // but peer servers could get it later... + // and we don't want to skip message that's over a min old. + + // Always check timestamp + // if time.Now().UnixMilli()-rawRpc.Timestamp > sharedConfig.SignatureTimeToLiveMs { + // return errors.New("Invalid timestamp") + // } + + // banned? + isBanned, err := vtor.isBanned(ctx, userId) + if err != nil { + return err + } + if isBanned { + return fmt.Errorf("user_id %d is banned from chat", userId) + } + + switch methodName { + case RPCMethodChatCreate: + return vtor.validateChatCreate(ctx, userId, rawRpc) + case RPCMethodChatDelete: + return vtor.validateChatDelete(userId, rawRpc) + case RPCMethodChatMessage: + return vtor.validateChatMessage(ctx, userId, rawRpc) + case RPCMethodChatReact: + return vtor.validateChatReact(vtor.pool, ctx, userId, rawRpc) + case RPCMethodChatRead: + return vtor.validateChatRead(userId, rawRpc) + case RPCMethodChatPermit: + return vtor.validateChatPermit(userId, rawRpc) + case RPCMethodChatBlock: + return vtor.validateChatBlock(userId, rawRpc) + case RPCMethodChatUnblock: + return vtor.validateChatUnblock(userId, rawRpc) + default: + vtor.logger.Debug("no validator for " + rawRpc.Method) + } + + return nil +} + +func (vtor *Validator) isBanned(ctx context.Context, userId int32) (bool, error) { + isBanned := false + err := vtor.pool.QueryRow(ctx, "select count(user_id) = 1 from chat_ban where user_id = $1 and is_banned = true", userId).Scan(&isBanned) + if err != nil { + return false, err + } + return isBanned, nil +} + +func (vtor *Validator) validateChatCreate(ctx context.Context, userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatCreateRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate chatId does not already exist + query := "select count(*) from chat where chat_id = $1;" + var chatCount int + err = vtor.pool.QueryRow(ctx, query, params.ChatID).Scan(&chatCount) + if err != nil { + return err + } + if chatCount != 0 { + return errors.New("Chat already exists") + } + + if len(params.Invites) != 2 { + return errors.New("Chat must have 2 members") + } + + user1, err := trashid.DecodeHashId(params.Invites[0].UserID) + if err != nil { + return err + } + user2, err := trashid.DecodeHashId(params.Invites[1].UserID) + if err != nil { + return err + } + + receiver := int32(user1) + if receiver == userId { + receiver = int32(user2) + } + + // Check that the creator is non-abusive + err = validateSenderPassesAbuseCheck(vtor.pool, ctx, userId, vtor.aaoServer) + if err != nil { + return err + } + + // if recipient is creating a chat from a blast + // we ignore the receiver's inbox settings + // because receiver has sent a blast to this user. + { + hasBlast, err := hasNewBlastFromUser(vtor.pool, ctx, userId, receiver) + if err != nil { + return err + } + if hasBlast { + return nil + } + } + + // validate receiver permits chats from sender + err = validatePermissions(vtor.pool, ctx, userId, receiver) + if err != nil { + return err + } + + // validate does not exceed new chat rate limit for any invited users + var users []int32 + for _, invite := range params.Invites { + userId, err := trashid.DecodeHashId(invite.UserID) + if err != nil { + return err + } + users = append(users, int32(userId)) + } + err = vtor.validateNewChatRateLimit(vtor.pool, ctx, users) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatMessage(ctx context.Context, userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatMessageRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate userId is a member of chatId in good standing + err = validateChatMembership(vtor.pool, ctx, userId, params.ChatID) + if err != nil { + return err + } + + // validate not blocked and can chat according to receiver's inbox permission settings + err = validatePermittedToMessage(vtor.pool, ctx, userId, params.ChatID) + if err != nil { + return err + } + + // validate does not exceed new message rate limit + err = vtor.validateNewMessageRateLimit(vtor.pool, ctx, userId, params.ChatID) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatReact(pool *dbv1.DBPools, ctx context.Context, userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatReactRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate userId is a member of chatId in good standing + err = validateChatMembership(vtor.pool, ctx, userId, params.ChatID) + if err != nil { + return err + } + + // validate message exists in chat + var exists bool + err = pool.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM chat_message WHERE chat_id = $1 AND message_id = $2)`, params.ChatID, params.MessageID).Scan(&exists) + if err != nil { + return err + } + if !exists { + return errors.New("message does not exist in chat") + } + + // validate not blocked and can chat according to receiver's inbox permission settings + err = validatePermittedToMessage(vtor.pool, ctx, userId, params.ChatID) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatRead(userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatReadRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate userId is a member of chatId in good standing + err = validateChatMembership(vtor.pool, context.Background(), userId, params.ChatID) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatPermit(userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatPermitRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatBlock(userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatBlockRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatUnblock(userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatBlockRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate that params.UserID is currently blocked by userId + blockeeUserId, err := trashid.DecodeHashId(params.UserID) + if err != nil { + return err + } + + var exists bool + err = vtor.pool.QueryRow(context.Background(), ` + select exists( + select 1 from chat_blocked_users + where blocker_user_id = $1 and blockee_user_id = $2 + ) + `, userId, blockeeUserId).Scan(&exists) + if err != nil { + return err + } + if !exists { + return errors.New("user is not blocked") + } + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatDelete(userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatDeleteRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate userId is a member of chatId in good standing + err = validateChatMembership(vtor.pool, context.Background(), userId, params.ChatID) + if err != nil { + return err + } + + return nil +} + +// Calculate cursor from rate limit timeframe +func (vtor *Validator) calculateRateLimitCursor(timeframe int) time.Time { + return time.Now().UTC().Add(-time.Hour * time.Duration(timeframe)) +} + +func (vtor *Validator) validateNewChatRateLimit(pool *dbv1.DBPools, ctx context.Context, users []int32) error { + var err error + + // rate_limit_seconds + + limiter := vtor.limiter + timeframe := limiter.Get(RateLimitTimeframeHours) + + // Max num of new chats permitted per timeframe + maxNumChats := limiter.Get(RateLimitMaxNumNewChats) + + cursor := vtor.calculateRateLimitCursor(timeframe) + + // Build the query with proper placeholders for the IN clause + query := ` + WITH counts AS ( + SELECT COUNT(*) AS count + FROM chat + JOIN chat_member on chat.chat_id = chat_member.chat_id + WHERE chat_member.user_id = ANY($1) AND chat.created_at > $2 + GROUP BY chat_member.user_id + ) + SELECT COALESCE(MAX(count), 0) FROM counts; + ` + + var numChats int + err = pool.QueryRow(ctx, query, users, cursor).Scan(&numChats) + if err != nil { + return err + } + if numChats >= maxNumChats { + vtor.logger.Info("hit rate limit (new chats)", zap.Any("users", users)) + return errors.New("An invited user has exceeded the maximum number of new chats") + } + + return nil +} + +func (vtor *Validator) validateNewMessageRateLimit(pool *dbv1.DBPools, ctx context.Context, userId int32, chatId string) error { + var err error + + // BurstRateLimit + { + query := ` + select + sum(case when created_at > now() - interval '1 second' then 1 else 0 end) as s1, + sum(case when created_at > now() - interval '10 seconds' then 1 else 0 end) as s10, + sum(case when created_at > now() - interval '60 seconds' then 1 else 0 end) as s60 + from chat_message + where user_id = $1 + and created_at > now() - interval '60 seconds'; + ` + var s1, s10, s60 sql.NullInt64 + err = pool.QueryRow(ctx, query, userId).Scan(&s1, &s10, &s60) + if err != nil { + slog.Error("burst rate limit query failed", "err", err) + } + + // 10 per second in last second + if s1.Int64 > 10 { + slog.Warn("message rate limit exceeded", "bucket", "1s", "user_id", userId, "count", s1) + return ErrMessageRateLimitExceeded + + } + + // 7 per second for last 10 seconds + if s10.Int64 > 70 { + slog.Warn("message rate limit exceeded", "bucket", "10s", "user_id", userId, "count", s10) + return ErrMessageRateLimitExceeded + } + + // 5 per second for last 60 seconds + if s60.Int64 > 300 { + slog.Warn("message rate limit exceeded", "bucket", "60s", "user_id", userId, "count", s60) + return ErrMessageRateLimitExceeded + } + } + + limiter := vtor.limiter + timeframe := limiter.Get(RateLimitTimeframeHours) + + // Max number of new messages permitted per timeframe + maxNumMessages := limiter.Get(RateLimitMaxNumMessages) + + // Max number of new messages permitted per recipient (chat) per timeframe + maxNumMessagesPerRecipient := limiter.Get(RateLimitMaxNumMessagesPerRecipient) + + // Cursor for rate limit timeframe + cursor := vtor.calculateRateLimitCursor(timeframe) + + // Check total message count and max messages per chat + query := ` + WITH counts_per_chat AS ( + SELECT COUNT(*) + FROM chat_message + WHERE user_id = $1 and created_at > $2 + GROUP BY chat_id + ) + SELECT COALESCE(SUM(count), 0) AS total_count, COALESCE(MAX(count), 0) as max_count_per_chat FROM counts_per_chat; + ` + + var totalCount, maxCountPerChat int + err = pool.QueryRow(ctx, query, userId, cursor).Scan(&totalCount, &maxCountPerChat) + if err != nil { + return err + } + if totalCount >= maxNumMessages || maxCountPerChat >= maxNumMessagesPerRecipient { + if totalCount >= maxNumMessages { + vtor.logger.Info("hit rate limit (total count new messages)", zap.Int32("user", userId), zap.String("chat", chatId)) + } + if maxCountPerChat >= maxNumMessagesPerRecipient { + vtor.logger.Info("hit rate limit (new messages per recipient)", zap.Int32("user", userId), zap.String("chat", chatId)) + } + return ErrMessageRateLimitExceeded + } + + return nil +} + +func validateChatMembership(pool *dbv1.DBPools, ctx context.Context, userId int32, chatId string) error { + var exists bool + err := pool.QueryRow(ctx, `select exists(select 1 from chat_member where user_id = $1 and chat_id = $2)`, userId, chatId).Scan(&exists) + if err != nil { + return err + } + if !exists { + return errors.New("user is not a member of this chat") + } + return nil +} + +// TODO: REMOVE? +// Recheck chat permissions before sending further messages if a member of the chat +// has cleared their chat history +// func RecheckPermissionsRequired(lastMessageAt time.Time, members []db.ChatMember) bool { +// for _, member := range members { +// if member.ClearedHistoryAt.Valid && member.ClearedHistoryAt.Time.After(lastMessageAt) { +// return true +// } +// } +// return false +// } + +func validatePermissions(pool *dbv1.DBPools, ctx context.Context, sender int32, receiver int32) error { + permissionFailure := errors.New("Not permitted to send messages to this user") + + ok := false + err := pool.QueryRow(ctx, `select chat_allowed($1, $2)`, sender, receiver).Scan(&ok) + if err != nil { + return err + } + if !ok { + return permissionFailure + } + return nil + +} + +func validatePermittedToMessage(pool *dbv1.DBPools, ctx context.Context, userId int32, chatId string) error { + // Single query that validates: + // 1. Chat has exactly 2 members + // 2. User is a member of the chat + // 3. User has permission to message the other member + query := ` + WITH chat_members AS ( + SELECT user_id + FROM chat_member + WHERE chat_id = $1 + ), + member_count AS ( + SELECT COUNT(*) as count + FROM chat_members + ), + other_member AS ( + SELECT user_id + FROM chat_members + WHERE user_id != $2 + ) + SELECT + CASE + WHEN mc.count != 2 THEN false + WHEN NOT EXISTS (SELECT 1 FROM chat_members WHERE user_id = $2) THEN false + WHEN NOT chat_allowed($2, om.user_id) THEN false + ELSE true + END as is_permitted + FROM member_count mc + CROSS JOIN other_member om + ` + + var isPermitted bool + err := pool.QueryRow(ctx, query, chatId, userId).Scan(&isPermitted) + if err != nil { + if err == sql.ErrNoRows { + return errors.New("Chat must have 2 members") + } + return err + } + + if !isPermitted { + return errors.New("Not permitted to send messages to this user") + } + + return nil +} + +var ErrAttestationFailed = errors.New("attestation failed") + +func validateSenderPassesAbuseCheck(pool *dbv1.DBPools, ctx context.Context, userId int32, aaoServer string) error { + // Keeping this somewhat opaque as it gets sent to client + var handle string + err := pool.QueryRow(ctx, `SELECT handle FROM users WHERE user_id = $1`, userId).Scan(&handle) + if err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("user %d not found", userId) + } + return err + } + + url := fmt.Sprintf("%s/attestation/%s", aaoServer, handle) + // Dummy challenge for now to mitigate + requestBody := []byte(`{ "challengeId": "x", "challengeSpecifier": "x", "amount": 0 }`) + resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody)) + if err != nil { + slog.Error("Error checking user attestation", "error", err, "handle", handle) + return err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + slog.Warn("User failed AAO check", "userId", userId, "status", resp.StatusCode) + return ErrAttestationFailed + } + return nil +} + +// HasNewBlastFromUser efficiently checks if a new blast exists from a specific user +// without fetching all blast data. Returns true if a valid blast exists, false otherwise. +func hasNewBlastFromUser(pool *dbv1.DBPools, ctx context.Context, userID int32, fromUserID int32) (bool, error) { + // Construct the expected chat ID for this user pair + expectedChatID := ChatID(int(userID), int(fromUserID)) + + // This query checks for the existence of a new blast from a specific user + // using the same logic as GetNewBlasts but optimized for existence check + var hasNewBlast = ` + with + last_permission_change as ( + select max(t) as t from ( + select updated_at as t from chat_permissions where user_id = $1 + union + select created_at as t from chat_blocked_users where blocker_user_id = $1 + union + select to_timestamp(0) + ) as timestamp_subquery + ) + select exists( + select 1 + from chat_blast blast + where + blast.from_user_id = $2 + and blast.created_at > (select t from last_permission_change) + and chat_allowed(blast.from_user_id, $1) + and not exists ( + select 1 from chat_member cm + where cm.user_id = $1 and cm.chat_id = $3 + ) + and ( + -- follower_audience + (blast.audience = 'follower_audience' and exists ( + SELECT 1 + FROM follows + WHERE follows.followee_user_id = blast.from_user_id + AND follows.follower_user_id = $1 + AND follows.is_delete = false + AND follows.created_at < blast.created_at + )) + OR + -- tipper_audience + (blast.audience = 'tipper_audience' and exists ( + SELECT 1 + FROM user_tips tip + WHERE receiver_user_id = blast.from_user_id + AND sender_user_id = $1 + AND tip.created_at < blast.created_at + )) + OR + -- remixer_audience + (blast.audience = 'remixer_audience' and exists ( + SELECT 1 + FROM tracks t + JOIN remixes ON remixes.child_track_id = t.track_id + JOIN tracks og ON remixes.parent_track_id = og.track_id + WHERE og.owner_id = blast.from_user_id + AND t.owner_id = $1 + AND ( + blast.audience_content_id IS NULL + OR ( + blast.audience_content_type = 'track' + AND blast.audience_content_id = og.track_id + ) + ) + )) + OR + -- customer_audience + (blast.audience = 'customer_audience' and exists ( + SELECT 1 + FROM usdc_purchases p + WHERE p.seller_user_id = blast.from_user_id + AND p.buyer_user_id = $1 + AND ( + blast.audience_content_id IS NULL + OR ( + blast.audience_content_type = p.content_type::text + AND blast.audience_content_id = p.content_id + ) + ) + )) + OR + -- coin_holder_audience + (blast.audience = 'coin_holder_audience' and exists ( + SELECT 1 + FROM artist_coins ac + JOIN sol_user_balances sub ON sub.mint = ac.mint + WHERE ac.user_id = blast.from_user_id + AND sub.user_id = $1 + AND sub.balance > 0 + -- TODO: PE-6663 This isn't entirely correct yet, need to check "time of most recent membership" + AND sub.created_at < blast.created_at + )) + ) + )` + + var exists bool + err := pool.QueryRow(ctx, hasNewBlast, userID, fromUserID, expectedChatID).Scan(&exists) + if err != nil { + return false, err + } + + return exists, nil +} diff --git a/api/comms_mutate.go b/api/comms_mutate.go new file mode 100644 index 00000000..b64b54f9 --- /dev/null +++ b/api/comms_mutate.go @@ -0,0 +1,67 @@ +package api + +import ( + "encoding/json" + "errors" + + comms "bridgerton.audius.co/api/comms" + + "github.com/AudiusProject/audiusd/pkg/logger" + "github.com/gofiber/fiber/v2" +) + +/* TODO List +- Decide if we want to validate first and then apply or validate inside apply within a transaction +- Attach instance of rpc processor or validator to the server? +*/ + +func (app *ApiServer) mutateChat(c *fiber.Ctx) error { + payload, wallet, err := comms.ReadSignedPost(c) + if err != nil { + return fiber.NewError(fiber.StatusBadRequest, "bad request: "+err.Error()) + } + + // unmarshal RPC and call validator + var rawRpc comms.RawRPC + err = json.Unmarshal(payload, &rawRpc) + if err != nil { + return fiber.NewError(fiber.StatusBadRequest, "bad request: "+err.Error()) + } + + // + // rpcLog := &schema.RpcLog{ + // RelayedBy: s.config.MyHost, + // RelayedAt: time.Now(), + // FromWallet: wallet, + // Rpc: payload, + // Sig: c.Request().Header.Get(signing.SigHeader), + // } + + // authedWallet := app.getAuthedWallet(c) + userId, err := app.getUserIDFromWallet(c.Context(), wallet) + if err != nil { + return err + } + + // userId, err := rpcz.GetRPCCurrentUserID(rpcLog, &rawRpc) + // if err != nil { + // return c.String(400, "wallet not found: "+err.Error()) + // } + + // call validator + err = s.proc.Validate(userId, rawRpc) + if err != nil { + if errors.Is(err, rpcz.ErrAttestationFailed) { + return c.JSON(403, "bad request: "+err.Error()) + } + return c.JSON(400, "bad request: "+err.Error()) + } + + ok, err := s.proc.Apply(rpcLog) + if err != nil { + logger.Warn(string(payload), "wallet", wallet, "err", err) + return err + } + logger.Debug(string(payload), "wallet", wallet, "relay", true) + return c.JSON(200, ok) +} diff --git a/api/server.go b/api/server.go index 0db5b7a7..ad1c4a15 100644 --- a/api/server.go +++ b/api/server.go @@ -500,6 +500,8 @@ func NewApiServer(config config.Config) *ApiServer { comms.Get("/blasts", app.getNewBlasts) + comms.Post("/mutate", app.mutateChat) + // Block confirmation app.Get("/block_confirmation", app.BlockConfirmation) app.Get("/block-confirmation", app.BlockConfirmation) From e2e982df5cce9cef1c9800e650bcb9633e7ea68b Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:11:38 -0400 Subject: [PATCH 2/7] migrate rpc processor --- api/comms/chat.go | 454 +++++++++++++++++++++++ api/comms/chat_blast.go | 120 ++++++ api/comms/chat_id.go | 5 + api/comms/get_new_blasts.go | 175 --------- api/comms/{apply.go => rpc_processor.go} | 183 +++++---- api/comms/schema.go | 14 +- api/comms/validator.go | 20 +- config/config.go | 4 + 8 files changed, 710 insertions(+), 265 deletions(-) create mode 100644 api/comms/chat.go create mode 100644 api/comms/chat_blast.go delete mode 100644 api/comms/get_new_blasts.go rename api/comms/{apply.go => rpc_processor.go} (52%) diff --git a/api/comms/chat.go b/api/comms/chat.go new file mode 100644 index 00000000..abcfb60d --- /dev/null +++ b/api/comms/chat.go @@ -0,0 +1,454 @@ +package comms + +import ( + "context" + "time" + + "bridgerton.audius.co/trashid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +func chatCreate(tx pgx.Tx, ctx context.Context, userId int32, ts time.Time, params ChatCreateRPCParams) error { + var err error + + // first find any blasts that should seed this chat ... + var blasts []blastRow + for _, invite := range params.Invites { + invitedUserId, err := trashid.DecodeHashId(invite.UserID) + if err != nil { + return err + } + + pending, err := getNewBlasts(tx, context.Background(), getNewBlastsParams{ + UserID: int32(invitedUserId), + ChatID: params.ChatID, + }) + if err != nil { + return err + } + blasts = append(blasts, pending...) + } + + // it is possible that two conflicting chats get created at the same time + // in which case there will be two different chat secrets + // to deterministically resolve this, if there is a conflict + // we keep the chat with the earliest relayed_at (created_at) timestamp + _, err = tx.Exec(ctx, ` + insert into chat + (chat_id, created_at, last_message_at) + values + ($1, $2, $2) + on conflict (chat_id) + do update set created_at = $2, last_message_at = $2 where chat.created_at > $2 + `, params.ChatID, ts) + if err != nil { + return err + } + + for _, invite := range params.Invites { + + invitedUserId, err := trashid.DecodeHashId(invite.UserID) + if err != nil { + return err + } + + // similar to above... if there is a conflict when creating chat_member records + // keep the version with the earliest relayed_at (created_at) timestamp. + _, err = tx.Exec(ctx, ` + insert into chat_member + (chat_id, invited_by_user_id, invite_code, user_id, created_at) + values + ($1, $2, $3, $4, $5) + on conflict (chat_id, user_id) + do update set invited_by_user_id=$2, invite_code=$3, created_at=$5 where chat_member.created_at > $5`, + params.ChatID, userId, invite.InviteCode, invitedUserId, ts) + if err != nil { + return err + } + + } + + for _, blast := range blasts { + _, err = tx.Exec(ctx, ` + insert into chat_message + (message_id, chat_id, user_id, created_at, blast_id) + values + ($1, $2, $3, $4, $5) + on conflict do nothing + `, BlastMessageID(blast.BlastID, params.ChatID), params.ChatID, blast.FromUserID, blast.CreatedAt, blast.BlastID) + if err != nil { + return err + } + } + + err = chatUpdateLatestFields(tx, ctx, params.ChatID) + + return err +} + +func chatDelete(tx pgx.Tx, ctx context.Context, userId int32, chatId string, messageTimestamp time.Time) error { + _, err := tx.Exec(ctx, "update chat_member set cleared_history_at = $1, last_active_at = $1, unread_count = 0, is_hidden = true where chat_id = $2 and user_id = $3", messageTimestamp, chatId, userId) + return err +} + +func chatUpdateLatestFields(tx pgx.Tx, ctx context.Context, chatId string) error { + // universal latest message thing + _, err := tx.Exec(ctx, ` + with latest as ( + select + m.chat_id, + m.created_at, + m.ciphertext, + m.blast_id, + b.plaintext + from + chat_message m + left join chat_blast b using (blast_id) + where m.chat_id = $1 + order by m.created_at desc + limit 1 + ) + update chat c + set + last_message_at = latest.created_at, + last_message = coalesce(latest.ciphertext, latest.plaintext), + last_message_is_plaintext = latest.blast_id is not null + from latest + where c.chat_id = latest.chat_id; + `, chatId) + if err != nil { + return err + } + + // set chat_member.is_hidden to false + // if there are any non-blast messages, reactions, + // or any blasts from the other party after cleared_history_at + _, err = tx.Exec(ctx, ` + UPDATE chat_member member + SET is_hidden = NOT EXISTS( + + -- Check for chat messages + SELECT msg.message_id + FROM chat_message msg + LEFT JOIN chat_blast b USING (blast_id) + WHERE msg.chat_id = member.chat_id + AND (cleared_history_at IS NULL OR msg.created_at > cleared_history_at) + AND (msg.blast_id IS NULL OR b.from_user_id != member.user_id) + + UNION + + -- Check for chat message reactions + SELECT r.message_id + FROM chat_message_reactions r + LEFT JOIN chat_message msg ON r.message_id = msg.message_id + WHERE msg.chat_id = member.chat_id + AND r.user_id != member.user_id + AND (cleared_history_at IS NULL OR (r.created_at > cleared_history_at AND msg.created_at > cleared_history_at)) + ), + unread_count = ( + select count(*) + from chat_message msg + where msg.created_at > COALESCE(member.last_active_at, '1970-01-01'::timestamp) + and msg.user_id != member.user_id + and msg.chat_id = member.chat_id + ) + WHERE member.chat_id = $1 + `, chatId) + return err +} + +func chatSendMessage(tx pgx.Tx, ctx context.Context, userId int32, chatId string, messageId string, messageTimestamp time.Time, ciphertext string) error { + var err error + + _, err = tx.Exec(ctx, "insert into chat_message (message_id, chat_id, user_id, created_at, ciphertext) values ($1, $2, $3, $4, $5)", + messageId, chatId, userId, messageTimestamp, ciphertext) + if err != nil { + return err + } + + // update chat's info on last message + err = chatUpdateLatestFields(tx, ctx, chatId) + if err != nil { + return err + } + + // sending a message implicitly marks activity for sender... + err = chatReadMessages(tx, ctx, userId, chatId, messageTimestamp) + if err != nil { + return err + } + + return err +} + +func chatReactMessage(tx pgx.Tx, ctx context.Context, userId int32, chatId string, messageId string, reaction *string, messageTimestamp time.Time) error { + var err error + if reaction != nil { + _, err = tx.Exec(ctx, ` + insert into chat_message_reactions + (user_id, message_id, reaction, created_at, updated_at) + values + ($1, $2, $3, $4, $4) + on conflict (user_id, message_id) + do update set reaction = $3, updated_at = $4 where chat_message_reactions.updated_at < $4`, + userId, messageId, *reaction, messageTimestamp) + } else { + _, err = tx.Exec(ctx, "delete from chat_message_reactions where user_id = $1 and message_id = $2 and updated_at < $3", userId, messageId, messageTimestamp) + } + if err != nil { + return err + } + + // update chat's info on reaction + err = chatUpdateLatestFields(tx, ctx, chatId) + return err +} + +func chatReadMessages(tx pgx.Tx, ctx context.Context, userId int32, chatId string, readTimestamp time.Time) error { + _, err := tx.Exec(ctx, "update chat_member set unread_count = 0, last_active_at = $1 where chat_id = $2 and user_id = $3", + readTimestamp, chatId, userId) + return err +} + +var permissions = []ChatPermission{ + ChatPermissionFollowees, + ChatPermissionFollowers, + ChatPermissionTippees, + ChatPermissionTippers, + ChatPermissionVerified, +} + +// Helper function to check if a permit is in the permitList +func isInPermitList(permit ChatPermission, permitList []ChatPermission) bool { + for _, p := range permitList { + if p == permit { + return true + } + } + return false +} + +func updatePermissions(tx pgx.Tx, ctx context.Context, userId int32, permit ChatPermission, permitAllowed bool, messageTimestamp time.Time) error { + _, err := tx.Exec(ctx, ` + insert into chat_permissions (user_id, permits, allowed, updated_at) + values ($1, $2, $3, $4) + on conflict (user_id, permits) + do update set allowed = $3 where chat_permissions.updated_at < $4 + `, userId, permit, permitAllowed, messageTimestamp) + return err +} + +func chatSetPermissions(tx pgx.Tx, ctx context.Context, userId int32, permits ChatPermission, permitList []ChatPermission, allow *bool, messageTimestamp time.Time) error { + + // if "all" or "none" or is singular permission style (allow == nil) delete any old rows + if allow == nil || permits == ChatPermissionAll || permits == ChatPermissionNone || isInPermitList(ChatPermissionAll, permitList) || isInPermitList(ChatPermissionNone, permitList) { + _, err := tx.Exec(ctx, ` + delete from chat_permissions where user_id = $1 and updated_at < $2 + `, userId, messageTimestamp) + if err != nil { + return err + } + } + + // old: singular permission style + if allow == nil { + // insert + _, err := tx.Exec(ctx, ` + insert into chat_permissions (user_id, permits, updated_at) + values ($1, $2, $3) + on conflict do nothing`, userId, permits, messageTimestamp) + return err + } + + // Special case for "all" and "none" - no other rows should be inserted + if isInPermitList(ChatPermissionAll, permitList) { + err := updatePermissions(tx, ctx, userId, ChatPermissionAll, true, messageTimestamp) + return err + } else if isInPermitList(ChatPermissionNone, permitList) { + err := updatePermissions(tx, ctx, userId, ChatPermissionNone, true, messageTimestamp) + return err + } + + // new: multiple (checkbox) permission style + for _, permit := range permissions { + permitAllowed := isInPermitList(permit, permitList) + err := updatePermissions(tx, ctx, userId, permit, permitAllowed, messageTimestamp) + if err != nil { + return err + } + } + return nil +} + +func chatBlock(tx pgx.Tx, ctx context.Context, userId int32, blockeeUserId int32, messageTimestamp time.Time) error { + _, err := tx.Exec(ctx, "insert into chat_blocked_users (blocker_user_id, blockee_user_id, created_at) values ($1, $2, $3) on conflict do nothing", userId, blockeeUserId, messageTimestamp) + return err +} + +func chatUnblock(tx pgx.Tx, ctx context.Context, userId int32, unblockedUserId int32, messageTimestamp time.Time) error { + _, err := tx.Exec(ctx, "delete from chat_blocked_users where blocker_user_id = $1 and blockee_user_id = $2 and created_at < $3", userId, unblockedUserId, messageTimestamp) + return err +} + +type blastRow struct { + PendingChatID string `db:"-" json:"pending_chat_id"` + BlastID string `db:"blast_id" json:"blast_id"` + FromUserID int32 `db:"from_user_id" json:"from_user_id"` + Audience string `db:"audience" json:"audience"` + AudienceContentType pgtype.Text `db:"audience_content_type" json:"audience_content_type"` + AudienceContentID pgtype.Int4 `db:"audience_content_id" json:"-"` + AudienceContentIDEncoded pgtype.Text `db:"-" json:"audience_content_id"` + Plaintext string `db:"plaintext" json:"plaintext"` + CreatedAt time.Time `db:"created_at" json:"created_at"` +} + +type getNewBlastsParams struct { + UserID int32 `db:"user_id" json:"user_id"` + ChatID string `db:"chat_id" json:"chat_id"` +} + +// Helper function to get new blasts as potential chat seeds for creating a chat +func getNewBlasts(tx pgx.Tx, ctx context.Context, arg getNewBlastsParams) ([]blastRow, error) { + + // this query is to find new blasts for the current user + // which don't already have a existing chat. + // see also: subtly different inverse query exists in chat_blast.go + // to fan out messages to existing chat + var findNewBlasts = ` + with + last_permission_change as ( + select max(t) as t from ( + select updated_at as t from chat_permissions where user_id = $1 + union + select created_at as t from chat_blocked_users where blocker_user_id = $1 + union + select to_timestamp(0) + ) as timestamp_subquery + ), + all_new as ( + select * + from chat_blast blast + where + from_user_id in ( + -- follower_audience + SELECT followee_user_id AS from_user_id + FROM follows + WHERE blast.audience = 'follower_audience' + AND follows.followee_user_id = blast.from_user_id + AND follows.follower_user_id = $1 + AND follows.is_delete = false + AND follows.created_at < blast.created_at + ) + OR from_user_id in ( + -- tipper_audience + SELECT receiver_user_id + FROM user_tips tip + WHERE blast.audience = 'tipper_audience' + AND receiver_user_id = blast.from_user_id + AND sender_user_id = $1 + AND tip.created_at < blast.created_at + ) + OR from_user_id IN ( + -- remixer_audience + SELECT og.owner_id + FROM tracks t + JOIN remixes ON remixes.child_track_id = t.track_id + JOIN tracks og ON remixes.parent_track_id = og.track_id + WHERE blast.audience = 'remixer_audience' + AND og.owner_id = blast.from_user_id + AND t.owner_id = $1 + AND ( + blast.audience_content_id IS NULL + OR ( + blast.audience_content_type = 'track' + AND blast.audience_content_id = og.track_id + ) + ) + ) + OR from_user_id IN ( + -- customer_audience + SELECT seller_user_id + FROM usdc_purchases p + WHERE blast.audience = 'customer_audience' + AND p.seller_user_id = blast.from_user_id + AND p.buyer_user_id = $1 + AND ( + audience_content_id IS NULL + OR ( + blast.audience_content_type = p.content_type::text + AND blast.audience_content_id = p.content_id + ) + ) + ) + OR from_user_id IN ( + -- coin_holder_audience via sol_user_balances + SELECT ac.user_id + FROM artist_coins ac + JOIN sol_user_balances sub ON sub.mint = ac.mint + WHERE blast.audience = 'coin_holder_audience' + AND ac.user_id = blast.from_user_id + AND sub.user_id = $1 + AND sub.balance > 0 + -- TODO: PE-6663 This isn't entirely correct yet, need to check "time of most recent membership" + AND sub.created_at < blast.created_at + ) + ) + select * from all_new + where created_at > (select t from last_permission_change) + and chat_allowed(from_user_id, $1) + order by created_at + ` + + rows, err := tx.Query(ctx, findNewBlasts, arg.UserID) + if err != nil { + return nil, err + } + + items, err := pgx.CollectRows(rows, pgx.RowToStructByName[blastRow]) + if err != nil { + return nil, err + } + + for idx, blastRow := range items { + chatId := ChatID(int(arg.UserID), int(blastRow.FromUserID)) + items[idx].PendingChatID = chatId + + if blastRow.AudienceContentID.Valid { + encoded, _ := trashid.EncodeHashId(int(blastRow.AudienceContentID.Int32)) + items[idx].AudienceContentIDEncoded.String = encoded + items[idx].AudienceContentIDEncoded.Valid = true + } + } + + rows, err = tx.Query(ctx, `select chat_id from chat_member where user_id = $1`, arg.UserID) + if err != nil { + return nil, err + } + + existingChatIdList, err := pgx.CollectRows(rows, pgx.RowTo[string]) + if err != nil { + return nil, err + } + + existingChatIds := map[string]bool{} + for _, id := range existingChatIdList { + existingChatIds[id] = true + } + + // filter out blast rows where chatIds is taken + filtered := make([]blastRow, 0, len(items)) + for _, item := range items { + if existingChatIds[item.PendingChatID] { + continue + } + // allow caller to filter to blasts for a given chat ID + if arg.ChatID != "" && item.PendingChatID != arg.ChatID { + continue + } + filtered = append(filtered, item) + } + + return filtered, err + +} diff --git a/api/comms/chat_blast.go b/api/comms/chat_blast.go new file mode 100644 index 00000000..e71393d4 --- /dev/null +++ b/api/comms/chat_blast.go @@ -0,0 +1,120 @@ +package comms + +import ( + "context" + "time" + + "bridgerton.audius.co/trashid" + "github.com/jackc/pgx/v5" +) + +/* +todo: + +- maybe blast_id should be computed like: `md5(from_user_id || audience || plaintext)` + +*/ +// Result struct to hold chat_id and to_user_id +type ChatBlastResult struct { + ChatID string `db:"chat_id"` + ToUserID int32 `db:"to_user_id"` +} + +type OutgoingChatMessage struct { + ChatMessageRPC ChatMessageRPC `json:"chat_message_rpc"` +} + +func chatBlast(tx pgx.Tx, ctx context.Context, userId int32, ts time.Time, params ChatBlastRPCParams) ([]OutgoingChatMessage, error) { + var audienceContentID *int + if params.AudienceContentID != nil { + id, _ := trashid.DecodeHashId(*params.AudienceContentID) + audienceContentID = &id + } + + // insert params.Message into chat_blast table + _, err := tx.Exec(ctx, ` + insert into chat_blast + (blast_id, from_user_id, audience, audience_content_type, audience_content_id, plaintext, created_at) + values + ($1, $2, $3, $4, $5, $6, $7) + on conflict (blast_id) + do nothing + `, params.BlastID, userId, params.Audience, params.AudienceContentType, audienceContentID, params.Message, ts) + if err != nil { + return nil, err + } + + // fan out messages to existing threads + // see also: similar but subtly different inverse query in `getNewBlasts helper in chat.go` + var results []ChatBlastResult + + fanOutSql := ` + WITH targ AS ( + SELECT + blast_id, + from_user_id, + to_user_id, + member_b.chat_id + FROM chat_blast + JOIN chat_blast_audience(chat_blast.blast_id) USING (blast_id) + LEFT JOIN chat_member member_a on from_user_id = member_a.user_id + LEFT JOIN chat_member member_b on to_user_id = member_b.user_id and member_b.chat_id = member_a.chat_id + WHERE blast_id = $1 + AND member_b.chat_id IS NOT NULL + AND chat_allowed(from_user_id, to_user_id) + ), + insert_message AS ( + INSERT INTO chat_message + (message_id, chat_id, user_id, created_at, blast_id) + SELECT + blast_id || targ.chat_id, -- this ordering needs to match Misc.BlastMessageID + targ.chat_id, + targ.from_user_id, + $2, + blast_id + FROM targ + ON conflict do nothing + ) + SELECT chat_id FROM targ; + ` + + rows, err := tx.Query(ctx, fanOutSql, params.BlastID, ts) + if err != nil { + return nil, err + } + defer rows.Close() + + // Scan the results into the results slice + results, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (ChatBlastResult, error) { + var result ChatBlastResult + err := row.Scan(&result.ChatID, &result.ToUserID) + return result, err + }) + if err != nil { + return nil, err + } + + // Formulate chat rpc messages for recipients who have an existing chat with sender + var outgoingMessages []OutgoingChatMessage + for _, result := range results { + messageID := BlastMessageID(params.BlastID, result.ChatID) + + isPlaintext := true + outgoingMessages = append(outgoingMessages, OutgoingChatMessage{ + ChatMessageRPC: ChatMessageRPC{ + Method: MethodChatMessage, + Params: ChatMessageRPCParams{ + ChatID: result.ChatID, + Message: params.Message, + MessageID: messageID, + IsPlaintext: &isPlaintext, + Audience: ¶ms.Audience, + }}}) + + if err := chatUpdateLatestFields(tx, ctx, result.ChatID); err != nil { + return nil, err + } + } + + return outgoingMessages, nil +} diff --git a/api/comms/chat_id.go b/api/comms/chat_id.go index 03605404..09a153c4 100644 --- a/api/comms/chat_id.go +++ b/api/comms/chat_id.go @@ -19,3 +19,8 @@ func ChatID(id1, id2 int) string { } return chatId } + +// Returns a unique Message ID for a blast message in a chat. +func BlastMessageID(blastID, chatID string) string { + return blastID + chatID +} diff --git a/api/comms/get_new_blasts.go b/api/comms/get_new_blasts.go deleted file mode 100644 index c1e9644c..00000000 --- a/api/comms/get_new_blasts.go +++ /dev/null @@ -1,175 +0,0 @@ -package comms - -import ( - "context" - "database/sql" - "time" - - "bridgerton.audius.co/api/dbv1" - "bridgerton.audius.co/trashid" - "github.com/jackc/pgx/v5" -) - -type BlastRow struct { - PendingChatID string `db:"-" json:"pending_chat_id"` - BlastID string `db:"blast_id" json:"blast_id"` - FromUserID int32 `db:"from_user_id" json:"from_user_id"` - Audience string `db:"audience" json:"audience"` - AudienceContentType sql.NullString `db:"audience_content_type" json:"audience_content_type"` - AudienceContentID sql.NullInt32 `db:"audience_content_id" json:"-"` - AudienceContentIDEncoded sql.NullString `db:"-" json:"audience_content_id"` - Plaintext string `db:"plaintext" json:"plaintext"` - CreatedAt time.Time `db:"created_at" json:"created_at"` -} - -type GetNewBlastsParams struct { - UserID int32 `db:"user_id" json:"user_id"` - ChatID string `db:"chat_id" json:"chat_id"` -} - -// Helper function to get new blasts, used by the mutation Validator -func GetNewBlasts(pool *dbv1.DBPools, ctx context.Context, arg GetNewBlastsParams) ([]BlastRow, error) { - - // this query is to find new blasts for the current user - // which don't already have a existing chat. - // see also: subtly different inverse query exists in chat_blast.go - // to fan out messages to existing chat - var findNewBlasts = ` - with - last_permission_change as ( - select max(t) as t from ( - select updated_at as t from chat_permissions where user_id = $1 - union - select created_at as t from chat_blocked_users where blocker_user_id = $1 - union - select to_timestamp(0) - ) as timestamp_subquery - ), - all_new as ( - select * - from chat_blast blast - where - from_user_id in ( - -- follower_audience - SELECT followee_user_id AS from_user_id - FROM follows - WHERE blast.audience = 'follower_audience' - AND follows.followee_user_id = blast.from_user_id - AND follows.follower_user_id = $1 - AND follows.is_delete = false - AND follows.created_at < blast.created_at - ) - OR from_user_id in ( - -- tipper_audience - SELECT receiver_user_id - FROM user_tips tip - WHERE blast.audience = 'tipper_audience' - AND receiver_user_id = blast.from_user_id - AND sender_user_id = $1 - AND tip.created_at < blast.created_at - ) - OR from_user_id IN ( - -- remixer_audience - SELECT og.owner_id - FROM tracks t - JOIN remixes ON remixes.child_track_id = t.track_id - JOIN tracks og ON remixes.parent_track_id = og.track_id - WHERE blast.audience = 'remixer_audience' - AND og.owner_id = blast.from_user_id - AND t.owner_id = $1 - AND ( - blast.audience_content_id IS NULL - OR ( - blast.audience_content_type = 'track' - AND blast.audience_content_id = og.track_id - ) - ) - ) - OR from_user_id IN ( - -- customer_audience - SELECT seller_user_id - FROM usdc_purchases p - WHERE blast.audience = 'customer_audience' - AND p.seller_user_id = blast.from_user_id - AND p.buyer_user_id = $1 - AND ( - audience_content_id IS NULL - OR ( - blast.audience_content_type = p.content_type::text - AND blast.audience_content_id = p.content_id - ) - ) - ) - OR from_user_id IN ( - -- coin_holder_audience via sol_user_balances - SELECT ac.user_id - FROM artist_coins ac - JOIN sol_user_balances sub ON sub.mint = ac.mint - WHERE blast.audience = 'coin_holder_audience' - AND ac.user_id = blast.from_user_id - AND sub.user_id = $1 - AND sub.balance > 0 - -- TODO: PE-6663 This isn't entirely correct yet, need to check "time of most recent membership" - AND sub.created_at < blast.created_at - ) - ) - select * from all_new - where created_at > (select t from last_permission_change) - and chat_allowed(from_user_id, $1) - order by created_at - ` - - rows, err := pool.Query(ctx, findNewBlasts, arg.UserID) - if err != nil { - return nil, err - } - - items, err := pgx.CollectRows(rows, pgx.RowToStructByName[BlastRow]) - if err != nil { - return nil, err - } - - for idx, blastRow := range items { - chatId := ChatID(int(arg.UserID), int(blastRow.FromUserID)) - items[idx].PendingChatID = chatId - - if blastRow.AudienceContentID.Valid { - encoded, _ := trashid.EncodeHashId(int(blastRow.AudienceContentID.Int32)) - items[idx].AudienceContentIDEncoded = sql.NullString{ - String: encoded, - Valid: true, - } - } - } - - rows, err = pool.Query(ctx, `select chat_id from chat_member where user_id = $1`, arg.UserID) - if err != nil { - return nil, err - } - - existingChatIdList, err := pgx.CollectRows(rows, pgx.RowTo[string]) - if err != nil { - return nil, err - } - - existingChatIds := map[string]bool{} - for _, id := range existingChatIdList { - existingChatIds[id] = true - } - - // filter out blast rows where chatIds is taken - filtered := make([]BlastRow, 0, len(items)) - for _, item := range items { - if existingChatIds[item.PendingChatID] { - continue - } - // allow caller to filter to blasts for a given chat ID - if arg.ChatID != "" && item.PendingChatID != arg.ChatID { - continue - } - filtered = append(filtered, item) - } - - return filtered, err - -} diff --git a/api/comms/apply.go b/api/comms/rpc_processor.go similarity index 52% rename from api/comms/apply.go rename to api/comms/rpc_processor.go index cf1f2031..f9b2acb7 100644 --- a/api/comms/apply.go +++ b/api/comms/rpc_processor.go @@ -9,29 +9,24 @@ import ( "time" "bridgerton.audius.co/api/dbv1" + "bridgerton.audius.co/config" "bridgerton.audius.co/trashid" + "go.uber.org/zap" - // "comms.audius.co/discovery/config" - // "comms.audius.co/discovery/db" - // "comms.audius.co/discovery/db/queries" - // "comms.audius.co/discovery/misc" - // "comms.audius.co/discovery/schema" - // "github.com/jmoiron/sqlx" - + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" "github.com/tidwall/gjson" - // "gorm.io/gorm/logger" ) type RPCProcessor struct { sync.Mutex pool *dbv1.DBPools + writePool *pgxpool.Pool validator *Validator - - // TODO - discoveryConfig *config.DiscoveryConfig + logger *zap.Logger } -func NewProcessor(pool *dbv1.DBPools, discoveryConfig *config.DiscoveryConfig) (*RPCProcessor, error) { +func NewProcessor(pool *dbv1.DBPools, writePool *pgxpool.Pool, config *config.Config, logger *zap.Logger) (*RPCProcessor, error) { // set up validator + limiter limiter, err := NewRateLimiter() @@ -39,24 +34,13 @@ func NewProcessor(pool *dbv1.DBPools, discoveryConfig *config.DiscoveryConfig) ( return nil, err } - aaoServer := "https://discoveryprovider.audius.co" - if discoveryConfig.IsStaging { - aaoServer = "https://discoveryprovider.staging.audius.co" - } - - if discoveryConfig.IsDev { - aaoServer = "http://audius-protocol-discovery-provider-1" - } - - validator := &Validator{ - pool: pool, - limiter: limiter, - aaoServer: aaoServer, - } + validator := NewValidator(pool, limiter, config, logger) proc := &RPCProcessor{ - validator: validator, - discoveryConfig: discoveryConfig, + validator: validator, + pool: pool, + writePool: writePool, + logger: logger, } return proc, nil @@ -65,20 +49,25 @@ func NewProcessor(pool *dbv1.DBPools, discoveryConfig *config.DiscoveryConfig) ( // TODO: replace logger // Clean up wallet recovery (do we even need it?) // Change format or at least naming of RpcLog +// TODO: logger with() functionality to track details of current rpc message // Do we still need to check for already applied? // - Maybe the validation needs to happen inside a transaction, since we check for existing stuff there? // Validates + applies a message -func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { - - logger := slog.With("sig", rpcLog.Sig) +func (proc *RPCProcessor) Apply(ctx context.Context, rpcLog *RpcLog) error { + logger := proc.logger.With( + zap.String("sig", rpcLog.Sig), + // TODO: audit + // zap.String("from_wallet", rpcLog.FromWallet), + // zap.String("relayed_by", rpcLog.RelayedBy), + // zap.Time("relayed_at", rpcLog.RelayedAt), + ) var err error - // check for already applied var exists int - db.Conn.Get(&exists, `select count(*) from rpc_log where sig = $1`, rpcLog.Sig) + proc.pool.QueryRow(ctx, `select count(*) from rpc_log where sig = $1`, rpcLog.Sig).Scan(&exists) if exists == 1 { - logger.Debug("rpc already in log, skipping duplicate", "sig", rpcLog.Sig) + logger.Debug("rpc already in log, skipping duplicate", zap.String("sig", rpcLog.Sig)) return nil } @@ -90,15 +79,15 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { } // validate signing wallet - wallet, err := misc.RecoverWallet(rpcLog.Rpc, rpcLog.Sig) + wallet, err := recoverSigningWallet(rpcLog.Sig, rpcLog.Rpc) if err != nil { logger.Warn("unable to recover wallet, skipping") return nil } - logger.Debug("recovered wallet", "took", takeSplit()) + logger.Debug("recovered wallet", zap.Duration("took", takeSplit())) if wallet != rpcLog.FromWallet { - logger.Warn("recovered wallet no match", "recovered", wallet, "expected", rpcLog.FromWallet, "realeyd_by", rpcLog.RelayedBy) + logger.Warn("recovered wallet no match", zap.String("recovered", wallet), zap.String("expected", rpcLog.FromWallet), zap.String("realeyd_by", rpcLog.RelayedBy)) return nil } @@ -112,19 +101,14 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { // check for "internal" message... if strings.HasPrefix(rawRpc.Method, "internal.") { - err := proc.applyInternalMessage(rpcLog, &rawRpc) - if err != nil { - logger.Info("failed to apply internal rpc", "error", err) - } else { - logger.Info("applied internal RPC", "sig", rpcLog.Sig) - } + logger.Warn("recieved internal message, skipping") return nil } // get ts messageTs := rpcLog.RelayedAt - userId, err := GetRPCCurrentUserID(rpcLog, &rawRpc) + userId, err := proc.GetRPCCurrentUserID(ctx, rpcLog, &rawRpc) if err != nil { logger.Info("unable to get user ID") return err // or nil? @@ -134,24 +118,25 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { chatId := gjson.GetBytes(rpcLog.Rpc, "params.chat_id").String() logger = logger.With( - "wallet", wallet, - "userId", userId, - "relayed_by", rpcLog.RelayedBy, - "relayed_at", rpcLog.RelayedAt, - "chat_id", chatId, - "sig", rpcLog.Sig) - logger.Debug("got user", "took", takeSplit()) + zap.String("wallet", wallet), + zap.Int32("userId", userId), + zap.String("relayed_by", rpcLog.RelayedBy), + zap.Time("relayed_at", rpcLog.RelayedAt), + zap.String("chat_id", chatId), + zap.String("sig", rpcLog.Sig), + ) + logger.Debug("got user", zap.Duration("took", takeSplit())) attemptApply := func() error { // write to db - tx, err := db.Conn.Beginx() + tx, err := proc.writePool.Begin(ctx) if err != nil { return err } - defer tx.Rollback() + defer tx.Rollback(ctx) - logger.Debug("begin tx", "took", takeSplit(), "sig", rpcLog.Sig) + logger.Debug("begin tx", zap.Duration("took", takeSplit()), zap.String("sig", rpcLog.Sig)) switch RPCMethod(rawRpc.Method) { case RPCMethodChatCreate: @@ -160,7 +145,7 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { if err != nil { return err } - err = chatCreate(tx, userId, messageTs, params) + err = chatCreate(tx, ctx, userId, messageTs, params) if err != nil { return err } @@ -170,7 +155,7 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { if err != nil { return err } - err = chatDelete(tx, userId, params.ChatID, messageTs) + err = chatDelete(tx, ctx, userId, params.ChatID, messageTs) if err != nil { return err } @@ -180,7 +165,7 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { if err != nil { return err } - err = chatSendMessage(tx, userId, params.ChatID, params.MessageID, messageTs, params.Message) + err = chatSendMessage(tx, ctx, userId, params.ChatID, params.MessageID, messageTs, params.Message) if err != nil { return err } @@ -190,7 +175,7 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { if err != nil { return err } - err = chatReactMessage(tx, userId, params.ChatID, params.MessageID, params.Reaction, messageTs) + err = chatReactMessage(tx, ctx, userId, params.ChatID, params.MessageID, params.Reaction, messageTs) if err != nil { return err } @@ -201,16 +186,19 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { if err != nil { return err } + // do nothing if last active at >= message timestamp - lastActive, err := queries.LastActiveAt(tx, context.Background(), queries.ChatMembershipParams{ - ChatID: params.ChatID, - UserID: userId, - }) + var lastActive pgtype.Timestamp + const lastActiveAtQuery = ` +select last_active_at from chat_member where chat_id = $1 and user_id = $2` + + err = tx.QueryRow(ctx, lastActiveAtQuery, params.ChatID, userId).Scan(&lastActive) if err != nil { return err } + if !lastActive.Valid || messageTs.After(lastActive.Time) { - err = chatReadMessages(tx, userId, params.ChatID, messageTs) + err = chatReadMessages(tx, ctx, userId, params.ChatID, messageTs) if err != nil { return err } @@ -221,7 +209,7 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { if err != nil { return err } - err = chatSetPermissions(tx, userId, params.Permit, params.PermitList, params.Allow, messageTs) + err = chatSetPermissions(tx, ctx, userId, params.Permit, params.PermitList, params.Allow, messageTs) if err != nil { return err } @@ -235,7 +223,7 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { if err != nil { return err } - err = chatBlock(tx, userId, int32(blockeeUserId), messageTs) + err = chatBlock(tx, ctx, userId, int32(blockeeUserId), messageTs) if err != nil { return err } @@ -249,7 +237,7 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { if err != nil { return err } - err = chatUnblock(tx, userId, int32(unblockedUserId), messageTs) + err = chatUnblock(tx, ctx, userId, int32(unblockedUserId), messageTs) if err != nil { return err } @@ -261,13 +249,13 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { return err } - outgoingMessages, err := chatBlast(tx, userId, messageTs, params) + outgoingMessages, err := chatBlast(tx, ctx, userId, messageTs, params) if err != nil { return err } // Send chat message websocket event to all recipients who have existing chats for _, outgoingMessage := range outgoingMessages { - j, err := json.Marshal(outgoingMessage.ChatMessageRPC) + _, err := json.Marshal(outgoingMessage.ChatMessageRPC) if err != nil { slog.Error("err: invalid json", "err", err) } else { @@ -276,32 +264,79 @@ func (proc *RPCProcessor) Apply(rpcLog *RpcLog) error { } } default: - logger.Warn("no handler for ", rawRpc.Method) + logger.Warn("no handler for ", zap.String("method", rawRpc.Method)) } - logger.Debug("called handler", "took", takeSplit()) + logger.Debug("called handler", zap.Duration("took", takeSplit())) - err = tx.Commit() + err = tx.Commit(ctx) if err != nil { return err } - logger.Debug("commited", "took", takeSplit()) + logger.Debug("commited", zap.Duration("took", takeSplit())) // TODO // send out websocket events fire + forget style // websocketNotify(rpcLog.Rpc, userId, messageTs.Round(time.Microsecond)) - logger.Debug("websocket push done", "took", takeSplit()) + logger.Debug("websocket push done", zap.Duration("took", takeSplit())) return nil } err = attemptApply() if err != nil { - logger.Warn("apply failed", "err", err) + logger.Warn("apply failed", zap.Error(err)) } return err } +func (proc *RPCProcessor) GetRPCCurrentUserID(ctx context.Context, rpcLog *RpcLog, rawRpc *RawRPC) (int32, error) { + walletAddress := rpcLog.FromWallet + encodedCurrentUserId := rawRpc.CurrentUserID + + // attempt to read the (newly added) current_user_id field + if encodedCurrentUserId != "" { + if u, err := trashid.DecodeHashId(encodedCurrentUserId); err == nil && u > 0 { + + const checkCurrentUserQuery = ` + select count(*) > 0 + from users + where is_current = true + and user_id = $1 + and wallet = lower($2) + and handle IS NOT NULL + and is_available = TRUE + and is_deactivated = FALSE + ` + // valid current_user_id + wallet combo? + // for now just check that the pair exists in the user table + // in the future this can check a "grants" table that a given operation is permitted + isValid := false + err := proc.pool.QueryRow(ctx, checkCurrentUserQuery, u, walletAddress).Scan(&isValid) + if err == nil && isValid { + return int32(u), nil + } else { + proc.logger.Warn("invalid current_user_id", zap.Error(err), zap.String("wallet", walletAddress), zap.String("encoded_user_id", encodedCurrentUserId), zap.Int("user_id", u)) + } + } + } + + const getUserIDFromWalletQuery = ` + select user_id + from users + where is_current = TRUE + and handle IS NOT NULL + and is_available = TRUE + and is_deactivated = FALSE + and wallet = LOWER($1) + order by user_id asc + ` + // fallback to looking up user_id using wallet alone + var userId int32 + err := proc.pool.QueryRow(ctx, getUserIDFromWalletQuery, walletAddress).Scan(&userId) + return userId, err +} + // func websocketNotify(rpcJson json.RawMessage, userId int32, timestamp time.Time) { // if chatId := gjson.GetBytes(rpcJson, "params.chat_id").String(); chatId != "" { diff --git a/api/comms/schema.go b/api/comms/schema.go index 25a4134a..3aac7523 100644 --- a/api/comms/schema.go +++ b/api/comms/schema.go @@ -389,13 +389,13 @@ const ( type ChatPermission string const ( - All ChatPermission = "all" - Followees ChatPermission = "followees" - Followers ChatPermission = "followers" - None ChatPermission = "none" - Tippees ChatPermission = "tippees" - Tippers ChatPermission = "tippers" - Verified ChatPermission = "verified" + ChatPermissionAll ChatPermission = "all" + ChatPermissionFollowees ChatPermission = "followees" + ChatPermissionFollowers ChatPermission = "followers" + ChatPermissionNone ChatPermission = "none" + ChatPermissionTippees ChatPermission = "tippees" + ChatPermissionTippers ChatPermission = "tippers" + ChatPermissionVerified ChatPermission = "verified" ) type RPCMethod string diff --git a/api/comms/validator.go b/api/comms/validator.go index 48a9357c..5d9db45d 100644 --- a/api/comms/validator.go +++ b/api/comms/validator.go @@ -3,7 +3,6 @@ package comms import ( "bytes" "context" - "database/sql" "encoding/json" "errors" "fmt" @@ -12,7 +11,9 @@ import ( "time" "bridgerton.audius.co/api/dbv1" + "bridgerton.audius.co/config" "bridgerton.audius.co/trashid" + "github.com/jackc/pgx/v5" "go.uber.org/zap" ) @@ -27,11 +28,12 @@ type Validator struct { aaoServer string } -func NewValidator(pool *dbv1.DBPools, limiter *RateLimiter, aaoServer string) *Validator { +func NewValidator(pool *dbv1.DBPools, limiter *RateLimiter, config *config.Config, logger *zap.Logger) *Validator { return &Validator{ pool: pool, limiter: limiter, - aaoServer: aaoServer, + aaoServer: config.AAOServer, + logger: logger, } } @@ -378,27 +380,27 @@ func (vtor *Validator) validateNewMessageRateLimit(pool *dbv1.DBPools, ctx conte where user_id = $1 and created_at > now() - interval '60 seconds'; ` - var s1, s10, s60 sql.NullInt64 + var s1, s10, s60 int64 err = pool.QueryRow(ctx, query, userId).Scan(&s1, &s10, &s60) if err != nil { slog.Error("burst rate limit query failed", "err", err) } // 10 per second in last second - if s1.Int64 > 10 { + if s1 > 10 { slog.Warn("message rate limit exceeded", "bucket", "1s", "user_id", userId, "count", s1) return ErrMessageRateLimitExceeded } // 7 per second for last 10 seconds - if s10.Int64 > 70 { + if s10 > 70 { slog.Warn("message rate limit exceeded", "bucket", "10s", "user_id", userId, "count", s10) return ErrMessageRateLimitExceeded } // 5 per second for last 60 seconds - if s60.Int64 > 300 { + if s60 > 300 { slog.Warn("message rate limit exceeded", "bucket", "60s", "user_id", userId, "count", s60) return ErrMessageRateLimitExceeded } @@ -518,7 +520,7 @@ func validatePermittedToMessage(pool *dbv1.DBPools, ctx context.Context, userId var isPermitted bool err := pool.QueryRow(ctx, query, chatId, userId).Scan(&isPermitted) if err != nil { - if err == sql.ErrNoRows { + if err == pgx.ErrNoRows { return errors.New("Chat must have 2 members") } return err @@ -538,7 +540,7 @@ func validateSenderPassesAbuseCheck(pool *dbv1.DBPools, ctx context.Context, use var handle string err := pool.QueryRow(ctx, `SELECT handle FROM users WHERE user_id = $1`, userId).Scan(&handle) if err != nil { - if err == sql.ErrNoRows { + if err == pgx.ErrNoRows { return fmt.Errorf("user %d not found", userId) } return err diff --git a/config/config.go b/config/config.go index b6933e16..c299ff56 100644 --- a/config/config.go +++ b/config/config.go @@ -38,6 +38,7 @@ type Config struct { BirdeyeToken string SolanaIndexerWorkers int SolanaIndexerRetryInterval time.Duration + AAOServer string } var Cfg = Config{ @@ -88,6 +89,7 @@ func init() { Cfg.PythonUpstreams = []string{ "http://audius-protocol-discovery-provider-1", } + Cfg.AAOServer = "http://audius-protocol-discovery-provider-1" case "stage": fallthrough case "staging": @@ -103,6 +105,7 @@ func init() { Cfg.Rewards = core_config.MakeRewards(core_config.StageClaimAuthorities, core_config.StageRewardExtensions) Cfg.AudiusdURL = "creatornode11.staging.audius.co" Cfg.ChainId = "audius-testnet-alpha" + Cfg.AAOServer = "https://discoveryprovider.staging.audius.co" case "prod": fallthrough case "production": @@ -120,6 +123,7 @@ func init() { Cfg.Rewards = core_config.MakeRewards(core_config.ProdClaimAuthorities, core_config.ProdRewardExtensions) Cfg.AudiusdURL = "creatornode.audius.co" Cfg.ChainId = "audius-mainnet-alpha-beta" + Cfg.AAOServer = "https://discoveryprovider.audius.co" default: log.Fatalf("Unknown environment: %s", env) } From 7cad3555327fa3d2045054c9d13d1a3228d9821a Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:47:23 -0400 Subject: [PATCH 3/7] finish wiring up mutate endpoint --- api/comms/chat.go | 4 +- api/comms/chat_blast.go | 2 +- api/comms/chat_id.go | 26 ----- api/comms/constants.go | 2 +- api/comms/rate_limit.go | 4 +- api/comms/rate_limit_test.go | 202 ----------------------------------- api/comms/rpc_processor.go | 16 +-- api/comms/validator.go | 35 ++---- api/comms_mutate.go | 43 +++----- api/server.go | 8 ++ trashid/chatid.go | 6 ++ 11 files changed, 50 insertions(+), 298 deletions(-) delete mode 100644 api/comms/chat_id.go delete mode 100644 api/comms/rate_limit_test.go diff --git a/api/comms/chat.go b/api/comms/chat.go index abcfb60d..fddab589 100644 --- a/api/comms/chat.go +++ b/api/comms/chat.go @@ -76,7 +76,7 @@ func chatCreate(tx pgx.Tx, ctx context.Context, userId int32, ts time.Time, para values ($1, $2, $3, $4, $5) on conflict do nothing - `, BlastMessageID(blast.BlastID, params.ChatID), params.ChatID, blast.FromUserID, blast.CreatedAt, blast.BlastID) + `, trashid.BlastMessageID(blast.BlastID, params.ChatID), params.ChatID, blast.FromUserID, blast.CreatedAt, blast.BlastID) if err != nil { return err } @@ -411,7 +411,7 @@ func getNewBlasts(tx pgx.Tx, ctx context.Context, arg getNewBlastsParams) ([]bla } for idx, blastRow := range items { - chatId := ChatID(int(arg.UserID), int(blastRow.FromUserID)) + chatId := trashid.ChatID(int(arg.UserID), int(blastRow.FromUserID)) items[idx].PendingChatID = chatId if blastRow.AudienceContentID.Valid { diff --git a/api/comms/chat_blast.go b/api/comms/chat_blast.go index e71393d4..5ee5862a 100644 --- a/api/comms/chat_blast.go +++ b/api/comms/chat_blast.go @@ -97,7 +97,7 @@ func chatBlast(tx pgx.Tx, ctx context.Context, userId int32, ts time.Time, param // Formulate chat rpc messages for recipients who have an existing chat with sender var outgoingMessages []OutgoingChatMessage for _, result := range results { - messageID := BlastMessageID(params.BlastID, result.ChatID) + messageID := trashid.BlastMessageID(params.BlastID, result.ChatID) isPlaintext := true outgoingMessages = append(outgoingMessages, OutgoingChatMessage{ diff --git a/api/comms/chat_id.go b/api/comms/chat_id.go deleted file mode 100644 index 09a153c4..00000000 --- a/api/comms/chat_id.go +++ /dev/null @@ -1,26 +0,0 @@ -package comms - -import ( - "fmt" - - "bridgerton.audius.co/trashid" -) - -// ChatID return a encodedUser1:encodedUser2 ID where encodedUser1 is < encodedUser2 -// which is the convention used to make chat IDs deterministic. -// See makeChatId in SDK: packages/common/src/store/pages/chat/utils.ts -func ChatID(id1, id2 int) string { - // TODO: Handle errors - user1IdEncoded, _ := trashid.EncodeHashId(id1) - user2IdEncoded, _ := trashid.EncodeHashId(id2) - chatId := fmt.Sprintf("%s:%s", user1IdEncoded, user2IdEncoded) - if user2IdEncoded < user1IdEncoded { - chatId = fmt.Sprintf("%s:%s", user2IdEncoded, user1IdEncoded) - } - return chatId -} - -// Returns a unique Message ID for a blast message in a chat. -func BlastMessageID(blastID, chatID string) string { - return blastID + chatID -} diff --git a/api/comms/constants.go b/api/comms/constants.go index 810f0ebb..6e642159 100644 --- a/api/comms/constants.go +++ b/api/comms/constants.go @@ -1,10 +1,10 @@ package comms var ( + // TODO: verify this is correct SigHeader = "x-sig" SignatureTimeToLiveMs = int64(1000 * 60 * 60 * 12) // 12 hours - // TODO: Do we need these configurable? // Rate limit config RateLimitRulesBucketName = "rateLimitRules" RateLimitTimeframeHours = "timeframeHours" diff --git a/api/comms/rate_limit.go b/api/comms/rate_limit.go index 95796b3a..5b940215 100644 --- a/api/comms/rate_limit.go +++ b/api/comms/rate_limit.go @@ -26,7 +26,5 @@ func (limiter *RateLimiter) Get(rule string) int { return val } - // TODO - // return config.DefaultRateLimitRules[rule] - return 0 + return DefaultRateLimitRules[rule] } diff --git a/api/comms/rate_limit_test.go b/api/comms/rate_limit_test.go deleted file mode 100644 index fd0d3220..00000000 --- a/api/comms/rate_limit_test.go +++ /dev/null @@ -1,202 +0,0 @@ -package comms - -import ( - "fmt" - "math/rand" - "strconv" - "testing" - "time" - - "comms.audius.co/discovery/db" - "comms.audius.co/discovery/misc" - "comms.audius.co/discovery/schema" - "github.com/stretchr/testify/assert" -) - -func TestRateLimit(t *testing.T) { - - // todo: update for no-nats - t.Skip() - - var err error - - // reset tables under test - _, err = db.Conn.Exec("truncate table chat cascade") - assert.NoError(t, err) - - // Add test rules - // testRules := map[string]int{ - // config.RateLimitTimeframeHours: 24, - // config.RateLimitMaxNumMessages: 3, - // config.RateLimitMaxNumMessagesPerRecipient: 2, - // config.RateLimitMaxNumNewChats: 2, - // } - // for rule, limit := range testRules { - // _, err := kv.PutString(rule, strconv.Itoa(limit)) - // assert.NoError(t, err) - // } - - tx := db.Conn.MustBegin() - - seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) - user1Id := seededRand.Int31() - user2Id := seededRand.Int31() - user3Id := seededRand.Int31() - user4Id := seededRand.Int31() - user5Id := seededRand.Int31() - - user1IdEncoded, err := misc.EncodeHashId(int(user1Id)) - assert.NoError(t, err) - user3IdEncoded, err := misc.EncodeHashId(int(user3Id)) - assert.NoError(t, err) - user4IdEncoded, err := misc.EncodeHashId(int(user4Id)) - assert.NoError(t, err) - user5IdEncoded, err := misc.EncodeHashId(int(user5Id)) - assert.NoError(t, err) - - // user1Id created a new chat with user2Id 48 hours ago - chatId1 := strconv.Itoa(seededRand.Int()) - chatTs := time.Now().UTC().Add(-time.Hour * time.Duration(48)) - _, err = tx.Exec("insert into chat (chat_id, created_at, last_message_at) values ($1, $2, $2)", chatId1, chatTs) - assert.NoError(t, err) - _, err = tx.Exec("insert into chat_member (chat_id, invited_by_user_id, invite_code, user_id, created_at) values ($1, $2, $1, $2, $4), ($1, $2, $1, $3, $4)", chatId1, user1Id, user2Id, chatTs) - assert.NoError(t, err) - - // user1Id messaged user2Id 48 hours ago - err = chatSendMessage(tx, user1Id, chatId1, "1", chatTs, "Hello") - assert.NoError(t, err) - - // user1Id messages user2Id twice now - message := "Hello today 1" - messageRpc := schema.RawRPC{ - Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId1, message)), - } - err = testValidator.validateChatMessage(tx, user1Id, messageRpc) - assert.NoError(t, err) - err = chatSendMessage(tx, user1Id, chatId1, "2", time.Now().UTC(), message) - assert.NoError(t, err) - message = "Hello today 2" - messageRpc = schema.RawRPC{ - Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId1, message)), - } - err = testValidator.validateChatMessage(tx, user1Id, messageRpc) - assert.NoError(t, err) - err = chatSendMessage(tx, user1Id, chatId1, "3", time.Now().UTC(), message) - assert.NoError(t, err) - - // user1Id messages user2Id a 3rd time - // Blocked by rate limiter (hit max # messages per recipient in the past 24 hours) - message = "Hello again again." - messageRpc = schema.RawRPC{ - Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId1, message)), - } - err = testValidator.validateChatMessage(tx, user1Id, messageRpc) - assert.ErrorContains(t, err, "User has exceeded the maximum number of new messages") - - // user1Id creates a new chat with user3Id (1 chat created in 24h) - chatId2 := strconv.Itoa(seededRand.Int()) - createRpc := schema.RawRPC{ - Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "invites": [{"user_id": "%s", "invite_code": "%s"}, {"user_id": "%s", "invite_code": "%s"}]}`, chatId2, user1IdEncoded, chatId2, user3IdEncoded, chatId2)), - } - err = testValidator.validateChatCreate(tx, user1Id, createRpc) - assert.NoError(t, err) - SetupChatWithMembers(t, tx, chatId2, user1Id, user3Id) - - // user1Id messages user3Id - // Still blocked by rate limiter (hit max # messages with user2Id in the past 24h) - message = "Hi user3Id" - messageRpc = schema.RawRPC{ - Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId2, message)), - } - err = testValidator.validateChatMessage(tx, user1Id, messageRpc) - assert.ErrorContains(t, err, "User has exceeded the maximum number of new messages") - - // Remove message 3 from db so can test other rate limits - _, err = tx.Exec("delete from chat_message where message_id = '3'") - assert.NoError(t, err) - - // user1Id should be able to message user3Id now - err = testValidator.validateChatMessage(tx, user1Id, messageRpc) - assert.NoError(t, err) - err = chatSendMessage(tx, user1Id, chatId2, "3", time.Now().UTC(), message) - assert.NoError(t, err) - - // user1Id creates a new chat with user4Id (2 chats created in 24h) - chatId3 := strconv.Itoa(seededRand.Int()) - createRpc = schema.RawRPC{ - Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "invites": [{"user_id": "%s", "invite_code": "%s"}, {"user_id": "%s", "invite_code": "%s"}]}`, chatId3, user1IdEncoded, chatId3, user4IdEncoded, chatId3)), - } - err = testValidator.validateChatCreate(tx, user1Id, createRpc) - assert.NoError(t, err) - SetupChatWithMembers(t, tx, chatId3, user1Id, user4Id) - - // user1Id messages user4Id - message = "Hi user4Id again" - messageRpc = schema.RawRPC{ - Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId3, message)), - } - err = testValidator.validateChatMessage(tx, user1Id, messageRpc) - assert.NoError(t, err) - err = chatSendMessage(tx, user1Id, chatId3, "4", time.Now().UTC(), message) - assert.NoError(t, err) - - // user1Id messages user4Id again (4th message to anyone in 24h) - // Blocked by rate limiter (hit max # messages in the past 24 hours) - message = "Hi user4Id again" - messageRpc = schema.RawRPC{ - Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId3, message)), - } - err = testValidator.validateChatMessage(tx, user1Id, messageRpc) - assert.ErrorContains(t, err, "User has exceeded the maximum number of new messages") - - // user1Id creates a new chat with user5Id (3 chats created in 24h) - // Blocked by rate limiter (hit max # new chats) - chatId4 := strconv.Itoa(seededRand.Int()) - createRpc = schema.RawRPC{ - Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "invites": [{"user_id": "%s", "invite_code": "%s"}, {"user_id": "%s", "invite_code": "%s"}]}`, chatId4, user1IdEncoded, chatId2, user5IdEncoded, chatId4)), - } - err = testValidator.validateChatCreate(tx, user1Id, createRpc) - assert.ErrorContains(t, err, "An invited user has exceeded the maximum number of new chats") - - tx.Rollback() -} - -func TestBurstRateLimit(t *testing.T) { - var err error - - // reset tables under test - _, err = db.Conn.Exec("truncate table chat cascade") - assert.NoError(t, err) - - tx := db.Conn.MustBegin() - defer tx.Rollback() - - seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) - chatId := strconv.Itoa(seededRand.Int()) - user1Id := seededRand.Int31() - user2Id := seededRand.Int31() - - SetupChatWithMembers(t, tx, chatId, user1Id, user2Id) - - // hit the 1 second limit... send a burst of messages - for i := 1; i < 16; i++ { - - message := fmt.Sprintf("burst %d", i) - err = chatSendMessage(tx, user1Id, chatId, message, time.Now().UTC(), message) - assert.NoError(t, err, "i is", i) - - messageRpc := schema.RawRPC{ - Params: []byte(fmt.Sprintf(`{"chat_id": "%s", "message": "%s"}`, chatId, message)), - } - err = testValidator.validateChatMessage(tx, user1Id, messageRpc) - - // first 10 messages are ok... - // and then the per-second rate limiter kicks in - if i <= 10 { - assert.NoError(t, err, "i is", i) - } else { - assert.ErrorIs(t, err, ErrMessageRateLimitExceeded, "i = ", i) - } - } - -} diff --git a/api/comms/rpc_processor.go b/api/comms/rpc_processor.go index f9b2acb7..90f6ad31 100644 --- a/api/comms/rpc_processor.go +++ b/api/comms/rpc_processor.go @@ -3,7 +3,6 @@ package comms import ( "context" "encoding/json" - "log/slog" "strings" "sync" "time" @@ -46,21 +45,14 @@ func NewProcessor(pool *dbv1.DBPools, writePool *pgxpool.Pool, config *config.Co return proc, nil } -// TODO: replace logger -// Clean up wallet recovery (do we even need it?) -// Change format or at least naming of RpcLog -// TODO: logger with() functionality to track details of current rpc message -// Do we still need to check for already applied? -// - Maybe the validation needs to happen inside a transaction, since we check for existing stuff there? +func (proc *RPCProcessor) Validate(ctx context.Context, userId int32, rawRpc RawRPC) error { + return proc.validator.Validate(ctx, userId, rawRpc) +} // Validates + applies a message func (proc *RPCProcessor) Apply(ctx context.Context, rpcLog *RpcLog) error { logger := proc.logger.With( zap.String("sig", rpcLog.Sig), - // TODO: audit - // zap.String("from_wallet", rpcLog.FromWallet), - // zap.String("relayed_by", rpcLog.RelayedBy), - // zap.Time("relayed_at", rpcLog.RelayedAt), ) var err error @@ -257,7 +249,7 @@ select last_active_at from chat_member where chat_id = $1 and user_id = $2` for _, outgoingMessage := range outgoingMessages { _, err := json.Marshal(outgoingMessage.ChatMessageRPC) if err != nil { - slog.Error("err: invalid json", "err", err) + logger.Error("err: invalid json", zap.Error(err)) } else { // TODO // websocketNotify(json.RawMessage(j), userId, messageTs.Round(time.Microsecond)) diff --git a/api/comms/validator.go b/api/comms/validator.go index 5d9db45d..569ed30b 100644 --- a/api/comms/validator.go +++ b/api/comms/validator.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "log/slog" "net/http" "time" @@ -40,16 +39,6 @@ func NewValidator(pool *dbv1.DBPools, limiter *RateLimiter, config *config.Confi func (vtor *Validator) Validate(ctx context.Context, userId int32, rawRpc RawRPC) error { methodName := RPCMethod(rawRpc.Method) - // actually skip timestamp check for now... - // POST endpoint will check for recency... - // but peer servers could get it later... - // and we don't want to skip message that's over a min old. - - // Always check timestamp - // if time.Now().UnixMilli()-rawRpc.Timestamp > sharedConfig.SignatureTimeToLiveMs { - // return errors.New("Invalid timestamp") - // } - // banned? isBanned, err := vtor.isBanned(ctx, userId) if err != nil { @@ -130,7 +119,7 @@ func (vtor *Validator) validateChatCreate(ctx context.Context, userId int32, rpc } // Check that the creator is non-abusive - err = validateSenderPassesAbuseCheck(vtor.pool, ctx, userId, vtor.aaoServer) + err = validateSenderPassesAbuseCheck(vtor.pool, ctx, vtor.logger, userId, vtor.aaoServer) if err != nil { return err } @@ -299,9 +288,6 @@ func (vtor *Validator) validateChatUnblock(userId int32, rpc RawRPC) error { if !exists { return errors.New("user is not blocked") } - if err != nil { - return err - } return nil } @@ -383,25 +369,25 @@ func (vtor *Validator) validateNewMessageRateLimit(pool *dbv1.DBPools, ctx conte var s1, s10, s60 int64 err = pool.QueryRow(ctx, query, userId).Scan(&s1, &s10, &s60) if err != nil { - slog.Error("burst rate limit query failed", "err", err) + vtor.logger.Error("burst rate limit query failed", zap.Error(err)) } // 10 per second in last second if s1 > 10 { - slog.Warn("message rate limit exceeded", "bucket", "1s", "user_id", userId, "count", s1) + vtor.logger.Warn("message rate limit exceeded", zap.String("bucket", "1s"), zap.Int32("user_id", userId), zap.Int64("count", s1)) return ErrMessageRateLimitExceeded } // 7 per second for last 10 seconds if s10 > 70 { - slog.Warn("message rate limit exceeded", "bucket", "10s", "user_id", userId, "count", s10) + vtor.logger.Warn("message rate limit exceeded", zap.String("bucket", "10s"), zap.Int32("user_id", userId), zap.Int64("count", s10)) return ErrMessageRateLimitExceeded } // 5 per second for last 60 seconds if s60 > 300 { - slog.Warn("message rate limit exceeded", "bucket", "60s", "user_id", userId, "count", s60) + vtor.logger.Warn("message rate limit exceeded", zap.String("bucket", "60s"), zap.Int32("user_id", userId), zap.Int64("count", s60)) return ErrMessageRateLimitExceeded } } @@ -535,7 +521,8 @@ func validatePermittedToMessage(pool *dbv1.DBPools, ctx context.Context, userId var ErrAttestationFailed = errors.New("attestation failed") -func validateSenderPassesAbuseCheck(pool *dbv1.DBPools, ctx context.Context, userId int32, aaoServer string) error { +// TODO: Better AAO usage that corresponds to the claim rewards code +func validateSenderPassesAbuseCheck(pool *dbv1.DBPools, ctx context.Context, logger *zap.Logger, userId int32, aaoServer string) error { // Keeping this somewhat opaque as it gets sent to client var handle string err := pool.QueryRow(ctx, `SELECT handle FROM users WHERE user_id = $1`, userId).Scan(&handle) @@ -551,23 +538,23 @@ func validateSenderPassesAbuseCheck(pool *dbv1.DBPools, ctx context.Context, use requestBody := []byte(`{ "challengeId": "x", "challengeSpecifier": "x", "amount": 0 }`) resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody)) if err != nil { - slog.Error("Error checking user attestation", "error", err, "handle", handle) + logger.Error("Error checking user attestation", zap.Error(err), zap.String("handle", handle)) return err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - slog.Warn("User failed AAO check", "userId", userId, "status", resp.StatusCode) + logger.Warn("User failed AAO check", zap.Int32("userId", userId), zap.Int("status", resp.StatusCode)) return ErrAttestationFailed } return nil } -// HasNewBlastFromUser efficiently checks if a new blast exists from a specific user +// hasNewBlastFromUser efficiently checks if a new blast exists from a specific user // without fetching all blast data. Returns true if a valid blast exists, false otherwise. func hasNewBlastFromUser(pool *dbv1.DBPools, ctx context.Context, userID int32, fromUserID int32) (bool, error) { // Construct the expected chat ID for this user pair - expectedChatID := ChatID(int(userID), int(fromUserID)) + expectedChatID := trashid.ChatID(int(userID), int(fromUserID)) // This query checks for the existence of a new blast from a specific user // using the same logic as GetNewBlasts but optimized for existence check diff --git a/api/comms_mutate.go b/api/comms_mutate.go index b64b54f9..783e3f3b 100644 --- a/api/comms_mutate.go +++ b/api/comms_mutate.go @@ -3,18 +3,14 @@ package api import ( "encoding/json" "errors" + "time" comms "bridgerton.audius.co/api/comms" - "github.com/AudiusProject/audiusd/pkg/logger" "github.com/gofiber/fiber/v2" + "go.uber.org/zap" ) -/* TODO List -- Decide if we want to validate first and then apply or validate inside apply within a transaction -- Attach instance of rpc processor or validator to the server? -*/ - func (app *ApiServer) mutateChat(c *fiber.Ctx) error { payload, wallet, err := comms.ReadSignedPost(c) if err != nil { @@ -28,40 +24,33 @@ func (app *ApiServer) mutateChat(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusBadRequest, "bad request: "+err.Error()) } - // - // rpcLog := &schema.RpcLog{ - // RelayedBy: s.config.MyHost, - // RelayedAt: time.Now(), - // FromWallet: wallet, - // Rpc: payload, - // Sig: c.Request().Header.Get(signing.SigHeader), - // } + rpcLog := &comms.RpcLog{ + RelayedBy: "bridge", + RelayedAt: time.Now(), + FromWallet: wallet, + Rpc: payload, + Sig: c.Get(comms.SigHeader), + } - // authedWallet := app.getAuthedWallet(c) userId, err := app.getUserIDFromWallet(c.Context(), wallet) if err != nil { return err } - // userId, err := rpcz.GetRPCCurrentUserID(rpcLog, &rawRpc) - // if err != nil { - // return c.String(400, "wallet not found: "+err.Error()) - // } - - // call validator - err = s.proc.Validate(userId, rawRpc) + // TODO: Decide if we want to validate first and then apply or validate inside apply within a transaction + err = app.commsRpcProcessor.Validate(c.Context(), int32(userId), rawRpc) if err != nil { - if errors.Is(err, rpcz.ErrAttestationFailed) { + if errors.Is(err, comms.ErrAttestationFailed) { return c.JSON(403, "bad request: "+err.Error()) } return c.JSON(400, "bad request: "+err.Error()) } - ok, err := s.proc.Apply(rpcLog) + err = app.commsRpcProcessor.Apply(c.Context(), rpcLog) if err != nil { - logger.Warn(string(payload), "wallet", wallet, "err", err) + app.logger.Warn("comms rpc apply failed", zap.String("payload", string(payload)), zap.String("wallet", wallet), zap.Error(err)) return err } - logger.Debug(string(payload), "wallet", wallet, "relay", true) - return c.JSON(200, ok) + app.logger.Debug("comms rpc apply succeeded", zap.String("payload", string(payload)), zap.String("wallet", wallet), zap.Bool("relay", true)) + return c.JSON(200) } diff --git a/api/server.go b/api/server.go index ad1c4a15..d1f72d1b 100644 --- a/api/server.go +++ b/api/server.go @@ -12,6 +12,7 @@ import ( "time" "bridgerton.audius.co/api/birdeye" + comms "bridgerton.audius.co/api/comms" "bridgerton.audius.co/api/dbv1" "bridgerton.audius.co/config" "bridgerton.audius.co/esindexer" @@ -163,6 +164,11 @@ func NewApiServer(config config.Config) *ApiServer { metricsCollector = NewMetricsCollector(logger, writePool) } + commsRpcProcessor, err := comms.NewProcessor(pool, writePool, &config, logger) + if err != nil { + panic(err) + } + app := &ApiServer{ App: fiber.New(fiber.Config{ JSONEncoder: json.Marshal, @@ -171,6 +177,7 @@ func NewApiServer(config config.Config) *ApiServer { ReadBufferSize: 32_768, UnescapePath: true, }), + commsRpcProcessor: commsRpcProcessor, env: config.Env, skipAuthCheck: skipAuthCheck, pool: pool, @@ -555,6 +562,7 @@ type BirdeyeClient interface { type ApiServer struct { *fiber.App + commsRpcProcessor *comms.RPCProcessor pool *dbv1.DBPools writePool *pgxpool.Pool queries *dbv1.Queries diff --git a/trashid/chatid.go b/trashid/chatid.go index 2109da2e..6716f4aa 100644 --- a/trashid/chatid.go +++ b/trashid/chatid.go @@ -2,6 +2,7 @@ package trashid import "fmt" +// TODO: Handle errors here if we can't encode the IDs // ChatID return a encodedUser1:encodedUser2 ID where encodedUser1 is < encodedUser2 // which is the convention used to make chat IDs deterministic. // See makeChatId in SDK: packages/common/src/store/pages/chat/utils.ts @@ -14,3 +15,8 @@ func ChatID(id1, id2 int) string { } return chatId } + +// Returns a unique Message ID for a blast message in a chat. +func BlastMessageID(blastID, chatID string) string { + return blastID + chatID +} From 22e89b795046d67680b4b7d9cef602c5cf35f916 Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Tue, 19 Aug 2025 17:04:39 -0400 Subject: [PATCH 4/7] cleanup --- api/comms/rpc_log.go | 9 +++--- api/comms/rpc_processor.go | 2 +- api/comms/signed_request.go | 59 ------------------------------------- api/comms/validator.go | 12 -------- 4 files changed, 5 insertions(+), 77 deletions(-) diff --git a/api/comms/rpc_log.go b/api/comms/rpc_log.go index 710c4a0d..2af4eb62 100644 --- a/api/comms/rpc_log.go +++ b/api/comms/rpc_log.go @@ -5,11 +5,10 @@ import ( "time" ) -// RpcLog is passed around between servers -// It has the rpc and sig (from header) -// The relay server that receives the RPC will stamp it with relayed_by and relayed_at -// relayed_by and relayed_at are used so that peer servers can consume a http-feeds style event feed -// from every peer +// RpcLog was previously used to track messages sent between comms peer servers. +// We are now using it as a record of RPC requests received from clients. +// RelayedAt will be the timestamp when the server received the request. +// RelayedBy will be hard-coded to "bridge" to differentiate it from legacy rpclog messages. type RpcLog struct { ID string `db:"id" json:"id"` RelayedAt time.Time `db:"relayed_at" json:"relayed_at"` diff --git a/api/comms/rpc_processor.go b/api/comms/rpc_processor.go index 90f6ad31..c5b439a8 100644 --- a/api/comms/rpc_processor.go +++ b/api/comms/rpc_processor.go @@ -91,7 +91,7 @@ func (proc *RPCProcessor) Apply(ctx context.Context, rpcLog *RpcLog) error { return nil } - // check for "internal" message... + // check for "internal" message, which are from the legacy implementation if strings.HasPrefix(rawRpc.Method, "internal.") { logger.Warn("recieved internal message, skipping") return nil diff --git a/api/comms/signed_request.go b/api/comms/signed_request.go index ade07d09..18b451f5 100644 --- a/api/comms/signed_request.go +++ b/api/comms/signed_request.go @@ -8,65 +8,6 @@ import ( "github.com/gofiber/fiber/v2" ) -// func userIdForSignedGet(c echo.Context) (int32, error) { -// if c.Request().Method != "GET" { -// return 0, errors.New("readSignedGet: bad method: " + c.Request().Method) -// } - -// sigBase64 := c.Request().Header.Get(signing.SigHeader) - -// // for websocket request, read from query param instead of header -// if querySig := c.QueryParam("signature"); sigBase64 == "" && querySig != "" { -// sigBase64 = querySig -// } - -// // helper function to log error if present -// logError := func(err error) (int32, error) { -// if err != nil { -// slog.Warn("ReadSignedRequest error", -// "err", err, -// "url", c.Request().URL.String(), -// "sig", sigBase64) -// } -// return 0, err -// } - -// // Check that timestamp is not too old -// timestamp, err := strconv.ParseInt(c.QueryParam("timestamp"), 0, 64) -// if err != nil { -// return logError(err) -// } - -// tsAge := time.Now().UnixMilli() - timestamp -// if tsAge < 0 { -// tsAge *= -1 -// } -// if tsAge > signing.SignatureTimeToLiveMs { -// return logError(errors.New("timestamp not current")) -// } - -// // Strip out the app_name and api_key query parameters to get the true signature payload -// u := *c.Request().URL -// q := u.Query() -// q.Del("app_name") -// q.Del("api_key") -// q.Del("signature") -// u.RawQuery = q.Encode() -// payload := []byte(u.String()) - -// wallet, err := recoverSigningWallet(sigBase64, payload) -// if err != nil { -// return logError(errors.New("failed to recoverSigningWallet")) -// } - -// userId, err := queries.GetUserIDFromWallet(db.Conn, c.Request().Context(), wallet, c.QueryParam("current_user_id")) -// if err != nil { -// return logError(fmt.Errorf("failed to get user_id for wallet: %s", wallet)) -// } - -// return userId, nil -// } - func ReadSignedPost(c *fiber.Ctx) ([]byte, string, error) { if c.Method() != "POST" { return nil, "", errors.New("readSignedPost bad method: " + c.Method()) diff --git a/api/comms/validator.go b/api/comms/validator.go index 569ed30b..c8c2b07f 100644 --- a/api/comms/validator.go +++ b/api/comms/validator.go @@ -445,18 +445,6 @@ func validateChatMembership(pool *dbv1.DBPools, ctx context.Context, userId int3 return nil } -// TODO: REMOVE? -// Recheck chat permissions before sending further messages if a member of the chat -// has cleared their chat history -// func RecheckPermissionsRequired(lastMessageAt time.Time, members []db.ChatMember) bool { -// for _, member := range members { -// if member.ClearedHistoryAt.Valid && member.ClearedHistoryAt.Time.After(lastMessageAt) { -// return true -// } -// } -// return false -// } - func validatePermissions(pool *dbv1.DBPools, ctx context.Context, sender int32, receiver int32) error { permissionFailure := errors.New("Not permitted to send messages to this user") From 604f59f282b62ff0ff7388bb6893ced6376f17d4 Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Thu, 21 Aug 2025 10:15:27 -0400 Subject: [PATCH 5/7] some fixes --- api/comms/rpc_processor.go | 26 ++++++++++++++++++++++++++ api/comms/validator.go | 8 ++++---- api/comms_mutate.go | 12 +++++++++--- api/server.go | 1 + 4 files changed, 40 insertions(+), 7 deletions(-) diff --git a/api/comms/rpc_processor.go b/api/comms/rpc_processor.go index c5b439a8..a0873b1d 100644 --- a/api/comms/rpc_processor.go +++ b/api/comms/rpc_processor.go @@ -12,6 +12,7 @@ import ( "bridgerton.audius.co/trashid" "go.uber.org/zap" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "github.com/tidwall/gjson" @@ -130,6 +131,18 @@ func (proc *RPCProcessor) Apply(ctx context.Context, rpcLog *RpcLog) error { logger.Debug("begin tx", zap.Duration("took", takeSplit()), zap.String("sig", rpcLog.Sig)) + count, err := insertRpcLogRow(tx, ctx, rpcLog) + if err != nil { + return err + } + if count == 0 { + // No rows were inserted because the sig (id) is already in rpc_log. + // Do not process redelivered messages that have already been processed. + logger.Info("rpc already in log, skipping duplicate", zap.String("sig", rpcLog.Sig)) + return nil + } + logger.Debug("inserted RPC", zap.Duration("took", takeSplit())) + switch RPCMethod(rawRpc.Method) { case RPCMethodChatCreate: var params ChatCreateRPCParams @@ -329,6 +342,19 @@ func (proc *RPCProcessor) GetRPCCurrentUserID(ctx context.Context, rpcLog *RpcLo return userId, err } +func insertRpcLogRow(tx pgx.Tx, ctx context.Context, rpcLog *RpcLog) (int64, error) { + query := ` + INSERT INTO rpc_log (relayed_by, relayed_at, applied_at, from_wallet, rpc, sig) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT DO NOTHING + ` + result, err := tx.Exec(ctx, query, rpcLog.RelayedBy, rpcLog.RelayedAt, time.Now(), rpcLog.FromWallet, rpcLog.Rpc, rpcLog.Sig) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} + // func websocketNotify(rpcJson json.RawMessage, userId int32, timestamp time.Time) { // if chatId := gjson.GetBytes(rpcJson, "params.chat_id").String(); chatId != "" { diff --git a/api/comms/validator.go b/api/comms/validator.go index c8c2b07f..ffe569ac 100644 --- a/api/comms/validator.go +++ b/api/comms/validator.go @@ -359,9 +359,9 @@ func (vtor *Validator) validateNewMessageRateLimit(pool *dbv1.DBPools, ctx conte { query := ` select - sum(case when created_at > now() - interval '1 second' then 1 else 0 end) as s1, - sum(case when created_at > now() - interval '10 seconds' then 1 else 0 end) as s10, - sum(case when created_at > now() - interval '60 seconds' then 1 else 0 end) as s60 + coalesce(sum(case when created_at > now() - interval '1 second' then 1 else 0 end), 0) as s1, + coalesce(sum(case when created_at > now() - interval '10 seconds' then 1 else 0 end), 0) as s10, + coalesce(sum(case when created_at > now() - interval '60 seconds' then 1 else 0 end), 0) as s60 from chat_message where user_id = $1 and created_at > now() - interval '60 seconds'; @@ -532,7 +532,7 @@ func validateSenderPassesAbuseCheck(pool *dbv1.DBPools, ctx context.Context, log defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - logger.Warn("User failed AAO check", zap.Int32("userId", userId), zap.Int("status", resp.StatusCode)) + logger.Warn("User failed AAO check", zap.Int32("userId", userId), zap.Int("status", resp.StatusCode), zap.String("aaoServer", aaoServer)) return ErrAttestationFailed } return nil diff --git a/api/comms_mutate.go b/api/comms_mutate.go index 783e3f3b..b0365c77 100644 --- a/api/comms_mutate.go +++ b/api/comms_mutate.go @@ -41,9 +41,15 @@ func (app *ApiServer) mutateChat(c *fiber.Ctx) error { err = app.commsRpcProcessor.Validate(c.Context(), int32(userId), rawRpc) if err != nil { if errors.Is(err, comms.ErrAttestationFailed) { - return c.JSON(403, "bad request: "+err.Error()) + return fiber.NewError( + fiber.StatusForbidden, + "forbidden to make this request: "+err.Error(), + ) } - return c.JSON(400, "bad request: "+err.Error()) + return fiber.NewError( + fiber.StatusBadRequest, + "bad request: "+err.Error(), + ) } err = app.commsRpcProcessor.Apply(c.Context(), rpcLog) @@ -52,5 +58,5 @@ func (app *ApiServer) mutateChat(c *fiber.Ctx) error { return err } app.logger.Debug("comms rpc apply succeeded", zap.String("payload", string(payload)), zap.String("wallet", wallet), zap.Bool("relay", true)) - return c.JSON(200) + return c.JSON(rpcLog) } diff --git a/api/server.go b/api/server.go index d1f72d1b..b1356274 100644 --- a/api/server.go +++ b/api/server.go @@ -296,6 +296,7 @@ func NewApiServer(config config.Config) *ApiServer { app.Use("/v1/full/tracks/best_new_releases", BalancerForward(config.PythonUpstreams)) app.Use("/v1/full/tracks/most_loved", BalancerForward(config.PythonUpstreams)) app.Use("/v1/full/tracks/remixables", BalancerForward(config.PythonUpstreams)) + app.Use("/comms/chats/ws", BalancerForward(config.PythonUpstreams)) } v1 := app.Group("/v1") From 246840b99101077d48bb97eb74fd1ca9f4263178 Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Thu, 21 Aug 2025 10:30:48 -0400 Subject: [PATCH 6/7] PR feedback and cleanup --- api/comms/chat_blast.go | 6 ------ api/comms/signed_request.go | 9 --------- api/comms/validator.go | 6 +++++- api/comms_mutate.go | 1 - config/config.go | 4 ---- 5 files changed, 5 insertions(+), 21 deletions(-) diff --git a/api/comms/chat_blast.go b/api/comms/chat_blast.go index 5ee5862a..39db8f6e 100644 --- a/api/comms/chat_blast.go +++ b/api/comms/chat_blast.go @@ -8,12 +8,6 @@ import ( "github.com/jackc/pgx/v5" ) -/* -todo: - -- maybe blast_id should be computed like: `md5(from_user_id || audience || plaintext)` - -*/ // Result struct to hold chat_id and to_user_id type ChatBlastResult struct { ChatID string `db:"chat_id"` diff --git a/api/comms/signed_request.go b/api/comms/signed_request.go index 18b451f5..f6f3be8c 100644 --- a/api/comms/signed_request.go +++ b/api/comms/signed_request.go @@ -36,14 +36,5 @@ func recoverSigningWallet(signatureHex string, signedData []byte) (string, error wallet := crypto.PubkeyToAddress(*pubkey).Hex() - // TODO: Still need this? We have a function for getting these in another file - // seed the user pubkey if missing - // err = pubkeystore.SetPubkeyForWallet(wallet, pubkey) - // if err != nil { - // slog.Warn("failed to SetPubkeyForWallet", "wallet", wallet, "err", err) - // } else { - // slog.Info("SetPubkeyForWallet OK", "wallet", wallet) - // } - return wallet, nil } diff --git a/api/comms/validator.go b/api/comms/validator.go index ffe569ac..a0fd3490 100644 --- a/api/comms/validator.go +++ b/api/comms/validator.go @@ -28,10 +28,14 @@ type Validator struct { } func NewValidator(pool *dbv1.DBPools, limiter *RateLimiter, config *config.Config, logger *zap.Logger) *Validator { + if len(config.AntiAbuseOracles) == 0 { + panic("no anti-abuse oracles configured, can't initialize comms validator") + } + return &Validator{ pool: pool, limiter: limiter, - aaoServer: config.AAOServer, + aaoServer: config.AntiAbuseOracles[0], logger: logger, } } diff --git a/api/comms_mutate.go b/api/comms_mutate.go index b0365c77..e5afa369 100644 --- a/api/comms_mutate.go +++ b/api/comms_mutate.go @@ -37,7 +37,6 @@ func (app *ApiServer) mutateChat(c *fiber.Ctx) error { return err } - // TODO: Decide if we want to validate first and then apply or validate inside apply within a transaction err = app.commsRpcProcessor.Validate(c.Context(), int32(userId), rawRpc) if err != nil { if errors.Is(err, comms.ErrAttestationFailed) { diff --git a/config/config.go b/config/config.go index c299ff56..b6933e16 100644 --- a/config/config.go +++ b/config/config.go @@ -38,7 +38,6 @@ type Config struct { BirdeyeToken string SolanaIndexerWorkers int SolanaIndexerRetryInterval time.Duration - AAOServer string } var Cfg = Config{ @@ -89,7 +88,6 @@ func init() { Cfg.PythonUpstreams = []string{ "http://audius-protocol-discovery-provider-1", } - Cfg.AAOServer = "http://audius-protocol-discovery-provider-1" case "stage": fallthrough case "staging": @@ -105,7 +103,6 @@ func init() { Cfg.Rewards = core_config.MakeRewards(core_config.StageClaimAuthorities, core_config.StageRewardExtensions) Cfg.AudiusdURL = "creatornode11.staging.audius.co" Cfg.ChainId = "audius-testnet-alpha" - Cfg.AAOServer = "https://discoveryprovider.staging.audius.co" case "prod": fallthrough case "production": @@ -123,7 +120,6 @@ func init() { Cfg.Rewards = core_config.MakeRewards(core_config.ProdClaimAuthorities, core_config.ProdRewardExtensions) Cfg.AudiusdURL = "creatornode.audius.co" Cfg.ChainId = "audius-mainnet-alpha-beta" - Cfg.AAOServer = "https://discoveryprovider.audius.co" default: log.Fatalf("Unknown environment: %s", env) } From e03357a87da9d5131eed75301a71131851be6a06 Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Thu, 21 Aug 2025 10:35:31 -0400 Subject: [PATCH 7/7] temp test hacks --- api/comms/validator.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/api/comms/validator.go b/api/comms/validator.go index a0fd3490..873431ad 100644 --- a/api/comms/validator.go +++ b/api/comms/validator.go @@ -28,14 +28,19 @@ type Validator struct { } func NewValidator(pool *dbv1.DBPools, limiter *RateLimiter, config *config.Config, logger *zap.Logger) *Validator { - if len(config.AntiAbuseOracles) == 0 { + // TODO: Don't hack around this for tests + if len(config.AntiAbuseOracles) == 0 && config.Env != "test" { panic("no anti-abuse oracles configured, can't initialize comms validator") } + aaoServer := "" + if len(config.AntiAbuseOracles) > 0 { + aaoServer = config.AntiAbuseOracles[0] + } return &Validator{ pool: pool, limiter: limiter, - aaoServer: config.AntiAbuseOracles[0], + aaoServer: aaoServer, logger: logger, } } @@ -515,6 +520,9 @@ var ErrAttestationFailed = errors.New("attestation failed") // TODO: Better AAO usage that corresponds to the claim rewards code func validateSenderPassesAbuseCheck(pool *dbv1.DBPools, ctx context.Context, logger *zap.Logger, userId int32, aaoServer string) error { + if aaoServer == "" { + return nil + } // Keeping this somewhat opaque as it gets sent to client var handle string err := pool.QueryRow(ctx, `SELECT handle FROM users WHERE user_id = $1`, userId).Scan(&handle)