diff --git a/api/comms/chat.go b/api/comms/chat.go new file mode 100644 index 00000000..fddab589 --- /dev/null +++ b/api/comms/chat.go @@ -0,0 +1,454 @@ +package comms + +import ( + "context" + "time" + + "bridgerton.audius.co/trashid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +func chatCreate(tx pgx.Tx, ctx context.Context, userId int32, ts time.Time, params ChatCreateRPCParams) error { + var err error + + // first find any blasts that should seed this chat ... + var blasts []blastRow + for _, invite := range params.Invites { + invitedUserId, err := trashid.DecodeHashId(invite.UserID) + if err != nil { + return err + } + + pending, err := getNewBlasts(tx, context.Background(), getNewBlastsParams{ + UserID: int32(invitedUserId), + ChatID: params.ChatID, + }) + if err != nil { + return err + } + blasts = append(blasts, pending...) + } + + // it is possible that two conflicting chats get created at the same time + // in which case there will be two different chat secrets + // to deterministically resolve this, if there is a conflict + // we keep the chat with the earliest relayed_at (created_at) timestamp + _, err = tx.Exec(ctx, ` + insert into chat + (chat_id, created_at, last_message_at) + values + ($1, $2, $2) + on conflict (chat_id) + do update set created_at = $2, last_message_at = $2 where chat.created_at > $2 + `, params.ChatID, ts) + if err != nil { + return err + } + + for _, invite := range params.Invites { + + invitedUserId, err := trashid.DecodeHashId(invite.UserID) + if err != nil { + return err + } + + // similar to above... if there is a conflict when creating chat_member records + // keep the version with the earliest relayed_at (created_at) timestamp. + _, err = tx.Exec(ctx, ` + insert into chat_member + (chat_id, invited_by_user_id, invite_code, user_id, created_at) + values + ($1, $2, $3, $4, $5) + on conflict (chat_id, user_id) + do update set invited_by_user_id=$2, invite_code=$3, created_at=$5 where chat_member.created_at > $5`, + params.ChatID, userId, invite.InviteCode, invitedUserId, ts) + if err != nil { + return err + } + + } + + for _, blast := range blasts { + _, err = tx.Exec(ctx, ` + insert into chat_message + (message_id, chat_id, user_id, created_at, blast_id) + values + ($1, $2, $3, $4, $5) + on conflict do nothing + `, trashid.BlastMessageID(blast.BlastID, params.ChatID), params.ChatID, blast.FromUserID, blast.CreatedAt, blast.BlastID) + if err != nil { + return err + } + } + + err = chatUpdateLatestFields(tx, ctx, params.ChatID) + + return err +} + +func chatDelete(tx pgx.Tx, ctx context.Context, userId int32, chatId string, messageTimestamp time.Time) error { + _, err := tx.Exec(ctx, "update chat_member set cleared_history_at = $1, last_active_at = $1, unread_count = 0, is_hidden = true where chat_id = $2 and user_id = $3", messageTimestamp, chatId, userId) + return err +} + +func chatUpdateLatestFields(tx pgx.Tx, ctx context.Context, chatId string) error { + // universal latest message thing + _, err := tx.Exec(ctx, ` + with latest as ( + select + m.chat_id, + m.created_at, + m.ciphertext, + m.blast_id, + b.plaintext + from + chat_message m + left join chat_blast b using (blast_id) + where m.chat_id = $1 + order by m.created_at desc + limit 1 + ) + update chat c + set + last_message_at = latest.created_at, + last_message = coalesce(latest.ciphertext, latest.plaintext), + last_message_is_plaintext = latest.blast_id is not null + from latest + where c.chat_id = latest.chat_id; + `, chatId) + if err != nil { + return err + } + + // set chat_member.is_hidden to false + // if there are any non-blast messages, reactions, + // or any blasts from the other party after cleared_history_at + _, err = tx.Exec(ctx, ` + UPDATE chat_member member + SET is_hidden = NOT EXISTS( + + -- Check for chat messages + SELECT msg.message_id + FROM chat_message msg + LEFT JOIN chat_blast b USING (blast_id) + WHERE msg.chat_id = member.chat_id + AND (cleared_history_at IS NULL OR msg.created_at > cleared_history_at) + AND (msg.blast_id IS NULL OR b.from_user_id != member.user_id) + + UNION + + -- Check for chat message reactions + SELECT r.message_id + FROM chat_message_reactions r + LEFT JOIN chat_message msg ON r.message_id = msg.message_id + WHERE msg.chat_id = member.chat_id + AND r.user_id != member.user_id + AND (cleared_history_at IS NULL OR (r.created_at > cleared_history_at AND msg.created_at > cleared_history_at)) + ), + unread_count = ( + select count(*) + from chat_message msg + where msg.created_at > COALESCE(member.last_active_at, '1970-01-01'::timestamp) + and msg.user_id != member.user_id + and msg.chat_id = member.chat_id + ) + WHERE member.chat_id = $1 + `, chatId) + return err +} + +func chatSendMessage(tx pgx.Tx, ctx context.Context, userId int32, chatId string, messageId string, messageTimestamp time.Time, ciphertext string) error { + var err error + + _, err = tx.Exec(ctx, "insert into chat_message (message_id, chat_id, user_id, created_at, ciphertext) values ($1, $2, $3, $4, $5)", + messageId, chatId, userId, messageTimestamp, ciphertext) + if err != nil { + return err + } + + // update chat's info on last message + err = chatUpdateLatestFields(tx, ctx, chatId) + if err != nil { + return err + } + + // sending a message implicitly marks activity for sender... + err = chatReadMessages(tx, ctx, userId, chatId, messageTimestamp) + if err != nil { + return err + } + + return err +} + +func chatReactMessage(tx pgx.Tx, ctx context.Context, userId int32, chatId string, messageId string, reaction *string, messageTimestamp time.Time) error { + var err error + if reaction != nil { + _, err = tx.Exec(ctx, ` + insert into chat_message_reactions + (user_id, message_id, reaction, created_at, updated_at) + values + ($1, $2, $3, $4, $4) + on conflict (user_id, message_id) + do update set reaction = $3, updated_at = $4 where chat_message_reactions.updated_at < $4`, + userId, messageId, *reaction, messageTimestamp) + } else { + _, err = tx.Exec(ctx, "delete from chat_message_reactions where user_id = $1 and message_id = $2 and updated_at < $3", userId, messageId, messageTimestamp) + } + if err != nil { + return err + } + + // update chat's info on reaction + err = chatUpdateLatestFields(tx, ctx, chatId) + return err +} + +func chatReadMessages(tx pgx.Tx, ctx context.Context, userId int32, chatId string, readTimestamp time.Time) error { + _, err := tx.Exec(ctx, "update chat_member set unread_count = 0, last_active_at = $1 where chat_id = $2 and user_id = $3", + readTimestamp, chatId, userId) + return err +} + +var permissions = []ChatPermission{ + ChatPermissionFollowees, + ChatPermissionFollowers, + ChatPermissionTippees, + ChatPermissionTippers, + ChatPermissionVerified, +} + +// Helper function to check if a permit is in the permitList +func isInPermitList(permit ChatPermission, permitList []ChatPermission) bool { + for _, p := range permitList { + if p == permit { + return true + } + } + return false +} + +func updatePermissions(tx pgx.Tx, ctx context.Context, userId int32, permit ChatPermission, permitAllowed bool, messageTimestamp time.Time) error { + _, err := tx.Exec(ctx, ` + insert into chat_permissions (user_id, permits, allowed, updated_at) + values ($1, $2, $3, $4) + on conflict (user_id, permits) + do update set allowed = $3 where chat_permissions.updated_at < $4 + `, userId, permit, permitAllowed, messageTimestamp) + return err +} + +func chatSetPermissions(tx pgx.Tx, ctx context.Context, userId int32, permits ChatPermission, permitList []ChatPermission, allow *bool, messageTimestamp time.Time) error { + + // if "all" or "none" or is singular permission style (allow == nil) delete any old rows + if allow == nil || permits == ChatPermissionAll || permits == ChatPermissionNone || isInPermitList(ChatPermissionAll, permitList) || isInPermitList(ChatPermissionNone, permitList) { + _, err := tx.Exec(ctx, ` + delete from chat_permissions where user_id = $1 and updated_at < $2 + `, userId, messageTimestamp) + if err != nil { + return err + } + } + + // old: singular permission style + if allow == nil { + // insert + _, err := tx.Exec(ctx, ` + insert into chat_permissions (user_id, permits, updated_at) + values ($1, $2, $3) + on conflict do nothing`, userId, permits, messageTimestamp) + return err + } + + // Special case for "all" and "none" - no other rows should be inserted + if isInPermitList(ChatPermissionAll, permitList) { + err := updatePermissions(tx, ctx, userId, ChatPermissionAll, true, messageTimestamp) + return err + } else if isInPermitList(ChatPermissionNone, permitList) { + err := updatePermissions(tx, ctx, userId, ChatPermissionNone, true, messageTimestamp) + return err + } + + // new: multiple (checkbox) permission style + for _, permit := range permissions { + permitAllowed := isInPermitList(permit, permitList) + err := updatePermissions(tx, ctx, userId, permit, permitAllowed, messageTimestamp) + if err != nil { + return err + } + } + return nil +} + +func chatBlock(tx pgx.Tx, ctx context.Context, userId int32, blockeeUserId int32, messageTimestamp time.Time) error { + _, err := tx.Exec(ctx, "insert into chat_blocked_users (blocker_user_id, blockee_user_id, created_at) values ($1, $2, $3) on conflict do nothing", userId, blockeeUserId, messageTimestamp) + return err +} + +func chatUnblock(tx pgx.Tx, ctx context.Context, userId int32, unblockedUserId int32, messageTimestamp time.Time) error { + _, err := tx.Exec(ctx, "delete from chat_blocked_users where blocker_user_id = $1 and blockee_user_id = $2 and created_at < $3", userId, unblockedUserId, messageTimestamp) + return err +} + +type blastRow struct { + PendingChatID string `db:"-" json:"pending_chat_id"` + BlastID string `db:"blast_id" json:"blast_id"` + FromUserID int32 `db:"from_user_id" json:"from_user_id"` + Audience string `db:"audience" json:"audience"` + AudienceContentType pgtype.Text `db:"audience_content_type" json:"audience_content_type"` + AudienceContentID pgtype.Int4 `db:"audience_content_id" json:"-"` + AudienceContentIDEncoded pgtype.Text `db:"-" json:"audience_content_id"` + Plaintext string `db:"plaintext" json:"plaintext"` + CreatedAt time.Time `db:"created_at" json:"created_at"` +} + +type getNewBlastsParams struct { + UserID int32 `db:"user_id" json:"user_id"` + ChatID string `db:"chat_id" json:"chat_id"` +} + +// Helper function to get new blasts as potential chat seeds for creating a chat +func getNewBlasts(tx pgx.Tx, ctx context.Context, arg getNewBlastsParams) ([]blastRow, error) { + + // this query is to find new blasts for the current user + // which don't already have a existing chat. + // see also: subtly different inverse query exists in chat_blast.go + // to fan out messages to existing chat + var findNewBlasts = ` + with + last_permission_change as ( + select max(t) as t from ( + select updated_at as t from chat_permissions where user_id = $1 + union + select created_at as t from chat_blocked_users where blocker_user_id = $1 + union + select to_timestamp(0) + ) as timestamp_subquery + ), + all_new as ( + select * + from chat_blast blast + where + from_user_id in ( + -- follower_audience + SELECT followee_user_id AS from_user_id + FROM follows + WHERE blast.audience = 'follower_audience' + AND follows.followee_user_id = blast.from_user_id + AND follows.follower_user_id = $1 + AND follows.is_delete = false + AND follows.created_at < blast.created_at + ) + OR from_user_id in ( + -- tipper_audience + SELECT receiver_user_id + FROM user_tips tip + WHERE blast.audience = 'tipper_audience' + AND receiver_user_id = blast.from_user_id + AND sender_user_id = $1 + AND tip.created_at < blast.created_at + ) + OR from_user_id IN ( + -- remixer_audience + SELECT og.owner_id + FROM tracks t + JOIN remixes ON remixes.child_track_id = t.track_id + JOIN tracks og ON remixes.parent_track_id = og.track_id + WHERE blast.audience = 'remixer_audience' + AND og.owner_id = blast.from_user_id + AND t.owner_id = $1 + AND ( + blast.audience_content_id IS NULL + OR ( + blast.audience_content_type = 'track' + AND blast.audience_content_id = og.track_id + ) + ) + ) + OR from_user_id IN ( + -- customer_audience + SELECT seller_user_id + FROM usdc_purchases p + WHERE blast.audience = 'customer_audience' + AND p.seller_user_id = blast.from_user_id + AND p.buyer_user_id = $1 + AND ( + audience_content_id IS NULL + OR ( + blast.audience_content_type = p.content_type::text + AND blast.audience_content_id = p.content_id + ) + ) + ) + OR from_user_id IN ( + -- coin_holder_audience via sol_user_balances + SELECT ac.user_id + FROM artist_coins ac + JOIN sol_user_balances sub ON sub.mint = ac.mint + WHERE blast.audience = 'coin_holder_audience' + AND ac.user_id = blast.from_user_id + AND sub.user_id = $1 + AND sub.balance > 0 + -- TODO: PE-6663 This isn't entirely correct yet, need to check "time of most recent membership" + AND sub.created_at < blast.created_at + ) + ) + select * from all_new + where created_at > (select t from last_permission_change) + and chat_allowed(from_user_id, $1) + order by created_at + ` + + rows, err := tx.Query(ctx, findNewBlasts, arg.UserID) + if err != nil { + return nil, err + } + + items, err := pgx.CollectRows(rows, pgx.RowToStructByName[blastRow]) + if err != nil { + return nil, err + } + + for idx, blastRow := range items { + chatId := trashid.ChatID(int(arg.UserID), int(blastRow.FromUserID)) + items[idx].PendingChatID = chatId + + if blastRow.AudienceContentID.Valid { + encoded, _ := trashid.EncodeHashId(int(blastRow.AudienceContentID.Int32)) + items[idx].AudienceContentIDEncoded.String = encoded + items[idx].AudienceContentIDEncoded.Valid = true + } + } + + rows, err = tx.Query(ctx, `select chat_id from chat_member where user_id = $1`, arg.UserID) + if err != nil { + return nil, err + } + + existingChatIdList, err := pgx.CollectRows(rows, pgx.RowTo[string]) + if err != nil { + return nil, err + } + + existingChatIds := map[string]bool{} + for _, id := range existingChatIdList { + existingChatIds[id] = true + } + + // filter out blast rows where chatIds is taken + filtered := make([]blastRow, 0, len(items)) + for _, item := range items { + if existingChatIds[item.PendingChatID] { + continue + } + // allow caller to filter to blasts for a given chat ID + if arg.ChatID != "" && item.PendingChatID != arg.ChatID { + continue + } + filtered = append(filtered, item) + } + + return filtered, err + +} diff --git a/api/comms/chat_blast.go b/api/comms/chat_blast.go new file mode 100644 index 00000000..39db8f6e --- /dev/null +++ b/api/comms/chat_blast.go @@ -0,0 +1,114 @@ +package comms + +import ( + "context" + "time" + + "bridgerton.audius.co/trashid" + "github.com/jackc/pgx/v5" +) + +// Result struct to hold chat_id and to_user_id +type ChatBlastResult struct { + ChatID string `db:"chat_id"` + ToUserID int32 `db:"to_user_id"` +} + +type OutgoingChatMessage struct { + ChatMessageRPC ChatMessageRPC `json:"chat_message_rpc"` +} + +func chatBlast(tx pgx.Tx, ctx context.Context, userId int32, ts time.Time, params ChatBlastRPCParams) ([]OutgoingChatMessage, error) { + var audienceContentID *int + if params.AudienceContentID != nil { + id, _ := trashid.DecodeHashId(*params.AudienceContentID) + audienceContentID = &id + } + + // insert params.Message into chat_blast table + _, err := tx.Exec(ctx, ` + insert into chat_blast + (blast_id, from_user_id, audience, audience_content_type, audience_content_id, plaintext, created_at) + values + ($1, $2, $3, $4, $5, $6, $7) + on conflict (blast_id) + do nothing + `, params.BlastID, userId, params.Audience, params.AudienceContentType, audienceContentID, params.Message, ts) + if err != nil { + return nil, err + } + + // fan out messages to existing threads + // see also: similar but subtly different inverse query in `getNewBlasts helper in chat.go` + var results []ChatBlastResult + + fanOutSql := ` + WITH targ AS ( + SELECT + blast_id, + from_user_id, + to_user_id, + member_b.chat_id + FROM chat_blast + JOIN chat_blast_audience(chat_blast.blast_id) USING (blast_id) + LEFT JOIN chat_member member_a on from_user_id = member_a.user_id + LEFT JOIN chat_member member_b on to_user_id = member_b.user_id and member_b.chat_id = member_a.chat_id + WHERE blast_id = $1 + AND member_b.chat_id IS NOT NULL + AND chat_allowed(from_user_id, to_user_id) + ), + insert_message AS ( + INSERT INTO chat_message + (message_id, chat_id, user_id, created_at, blast_id) + SELECT + blast_id || targ.chat_id, -- this ordering needs to match Misc.BlastMessageID + targ.chat_id, + targ.from_user_id, + $2, + blast_id + FROM targ + ON conflict do nothing + ) + SELECT chat_id FROM targ; + ` + + rows, err := tx.Query(ctx, fanOutSql, params.BlastID, ts) + if err != nil { + return nil, err + } + defer rows.Close() + + // Scan the results into the results slice + results, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (ChatBlastResult, error) { + var result ChatBlastResult + err := row.Scan(&result.ChatID, &result.ToUserID) + return result, err + }) + if err != nil { + return nil, err + } + + // Formulate chat rpc messages for recipients who have an existing chat with sender + var outgoingMessages []OutgoingChatMessage + for _, result := range results { + messageID := trashid.BlastMessageID(params.BlastID, result.ChatID) + + isPlaintext := true + outgoingMessages = append(outgoingMessages, OutgoingChatMessage{ + ChatMessageRPC: ChatMessageRPC{ + Method: MethodChatMessage, + Params: ChatMessageRPCParams{ + ChatID: result.ChatID, + Message: params.Message, + MessageID: messageID, + IsPlaintext: &isPlaintext, + Audience: ¶ms.Audience, + }}}) + + if err := chatUpdateLatestFields(tx, ctx, result.ChatID); err != nil { + return nil, err + } + } + + return outgoingMessages, nil +} diff --git a/api/comms/constants.go b/api/comms/constants.go new file mode 100644 index 00000000..6e642159 --- /dev/null +++ b/api/comms/constants.go @@ -0,0 +1,21 @@ +package comms + +var ( + // TODO: verify this is correct + SigHeader = "x-sig" + SignatureTimeToLiveMs = int64(1000 * 60 * 60 * 12) // 12 hours + + // Rate limit config + RateLimitRulesBucketName = "rateLimitRules" + RateLimitTimeframeHours = "timeframeHours" + RateLimitMaxNumMessages = "maxNumMessages" + RateLimitMaxNumMessagesPerRecipient = "maxNumMessagesPerRecipient" + RateLimitMaxNumNewChats = "maxNumNewChats" + + DefaultRateLimitRules = map[string]int{ + RateLimitTimeframeHours: 24, + RateLimitMaxNumMessages: 2000, + RateLimitMaxNumMessagesPerRecipient: 1000, + RateLimitMaxNumNewChats: 100000, + } +) diff --git a/api/comms/rate_limit.go b/api/comms/rate_limit.go new file mode 100644 index 00000000..5b940215 --- /dev/null +++ b/api/comms/rate_limit.go @@ -0,0 +1,30 @@ +package comms + +import ( + "sync" +) + +func NewRateLimiter() (*RateLimiter, error) { + + limiter := &RateLimiter{ + limits: map[string]int{}, + } + + return limiter, nil +} + +type RateLimiter struct { + sync.RWMutex + limits map[string]int +} + +func (limiter *RateLimiter) Get(rule string) int { + limiter.RLock() + defer limiter.RUnlock() + + if val := limiter.limits[rule]; val != 0 { + return val + } + + return DefaultRateLimitRules[rule] +} diff --git a/api/comms/raw_rpc.go b/api/comms/raw_rpc.go new file mode 100644 index 00000000..7cd4d5c1 --- /dev/null +++ b/api/comms/raw_rpc.go @@ -0,0 +1,16 @@ +package comms + +import ( + "encoding/json" +) + +// RawRPC matches (quicktype generated) RPC +// Except Params is a json.RawMessage instead of a quicktype approximation of a golang union type which sadly doesn't really exist. +// which is more generic + convienent to use in go code +// it should match the fields of RPC +type RawRPC struct { + CurrentUserID string `json:"current_user_id"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` + Timestamp int64 `json:"timestamp"` +} diff --git a/api/comms/rpc_log.go b/api/comms/rpc_log.go new file mode 100644 index 00000000..2af4eb62 --- /dev/null +++ b/api/comms/rpc_log.go @@ -0,0 +1,20 @@ +package comms + +import ( + "encoding/json" + "time" +) + +// RpcLog was previously used to track messages sent between comms peer servers. +// We are now using it as a record of RPC requests received from clients. +// RelayedAt will be the timestamp when the server received the request. +// RelayedBy will be hard-coded to "bridge" to differentiate it from legacy rpclog messages. +type RpcLog struct { + ID string `db:"id" json:"id"` + RelayedAt time.Time `db:"relayed_at" json:"relayed_at"` + AppliedAt time.Time `db:"applied_at" json:"applied_at"` + RelayedBy string `db:"relayed_by" json:"relayed_by"` + FromWallet string `db:"from_wallet" json:"from_wallet"` + Rpc json.RawMessage `db:"rpc" json:"rpc"` + Sig string `db:"sig" json:"sig"` +} diff --git a/api/comms/rpc_processor.go b/api/comms/rpc_processor.go new file mode 100644 index 00000000..a0873b1d --- /dev/null +++ b/api/comms/rpc_processor.go @@ -0,0 +1,378 @@ +package comms + +import ( + "context" + "encoding/json" + "strings" + "sync" + "time" + + "bridgerton.audius.co/api/dbv1" + "bridgerton.audius.co/config" + "bridgerton.audius.co/trashid" + "go.uber.org/zap" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/tidwall/gjson" +) + +type RPCProcessor struct { + sync.Mutex + pool *dbv1.DBPools + writePool *pgxpool.Pool + validator *Validator + logger *zap.Logger +} + +func NewProcessor(pool *dbv1.DBPools, writePool *pgxpool.Pool, config *config.Config, logger *zap.Logger) (*RPCProcessor, error) { + + // set up validator + limiter + limiter, err := NewRateLimiter() + if err != nil { + return nil, err + } + + validator := NewValidator(pool, limiter, config, logger) + + proc := &RPCProcessor{ + validator: validator, + pool: pool, + writePool: writePool, + logger: logger, + } + + return proc, nil +} + +func (proc *RPCProcessor) Validate(ctx context.Context, userId int32, rawRpc RawRPC) error { + return proc.validator.Validate(ctx, userId, rawRpc) +} + +// Validates + applies a message +func (proc *RPCProcessor) Apply(ctx context.Context, rpcLog *RpcLog) error { + logger := proc.logger.With( + zap.String("sig", rpcLog.Sig), + ) + var err error + + var exists int + proc.pool.QueryRow(ctx, `select count(*) from rpc_log where sig = $1`, rpcLog.Sig).Scan(&exists) + if exists == 1 { + logger.Debug("rpc already in log, skipping duplicate", zap.String("sig", rpcLog.Sig)) + return nil + } + + startTime := time.Now() + takeSplit := func() time.Duration { + split := time.Since(startTime) + startTime = time.Now() + return split + } + + // validate signing wallet + wallet, err := recoverSigningWallet(rpcLog.Sig, rpcLog.Rpc) + if err != nil { + logger.Warn("unable to recover wallet, skipping") + return nil + } + logger.Debug("recovered wallet", zap.Duration("took", takeSplit())) + + if wallet != rpcLog.FromWallet { + logger.Warn("recovered wallet no match", zap.String("recovered", wallet), zap.String("expected", rpcLog.FromWallet), zap.String("realeyd_by", rpcLog.RelayedBy)) + return nil + } + + // parse raw rpc + var rawRpc RawRPC + err = json.Unmarshal(rpcLog.Rpc, &rawRpc) + if err != nil { + logger.Info(err.Error()) + return nil + } + + // check for "internal" message, which are from the legacy implementation + if strings.HasPrefix(rawRpc.Method, "internal.") { + logger.Warn("recieved internal message, skipping") + return nil + } + + // get ts + messageTs := rpcLog.RelayedAt + + userId, err := proc.GetRPCCurrentUserID(ctx, rpcLog, &rawRpc) + if err != nil { + logger.Info("unable to get user ID") + return err // or nil? + } + + // for debugging + chatId := gjson.GetBytes(rpcLog.Rpc, "params.chat_id").String() + + logger = logger.With( + zap.String("wallet", wallet), + zap.Int32("userId", userId), + zap.String("relayed_by", rpcLog.RelayedBy), + zap.Time("relayed_at", rpcLog.RelayedAt), + zap.String("chat_id", chatId), + zap.String("sig", rpcLog.Sig), + ) + logger.Debug("got user", zap.Duration("took", takeSplit())) + + attemptApply := func() error { + + // write to db + tx, err := proc.writePool.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + logger.Debug("begin tx", zap.Duration("took", takeSplit()), zap.String("sig", rpcLog.Sig)) + + count, err := insertRpcLogRow(tx, ctx, rpcLog) + if err != nil { + return err + } + if count == 0 { + // No rows were inserted because the sig (id) is already in rpc_log. + // Do not process redelivered messages that have already been processed. + logger.Info("rpc already in log, skipping duplicate", zap.String("sig", rpcLog.Sig)) + return nil + } + logger.Debug("inserted RPC", zap.Duration("took", takeSplit())) + + switch RPCMethod(rawRpc.Method) { + case RPCMethodChatCreate: + var params ChatCreateRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + err = chatCreate(tx, ctx, userId, messageTs, params) + if err != nil { + return err + } + case RPCMethodChatDelete: + var params ChatDeleteRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + err = chatDelete(tx, ctx, userId, params.ChatID, messageTs) + if err != nil { + return err + } + case RPCMethodChatMessage: + var params ChatMessageRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + err = chatSendMessage(tx, ctx, userId, params.ChatID, params.MessageID, messageTs, params.Message) + if err != nil { + return err + } + case RPCMethodChatReact: + var params ChatReactRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + err = chatReactMessage(tx, ctx, userId, params.ChatID, params.MessageID, params.Reaction, messageTs) + if err != nil { + return err + } + + case RPCMethodChatRead: + var params ChatReadRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + + // do nothing if last active at >= message timestamp + var lastActive pgtype.Timestamp + const lastActiveAtQuery = ` +select last_active_at from chat_member where chat_id = $1 and user_id = $2` + + err = tx.QueryRow(ctx, lastActiveAtQuery, params.ChatID, userId).Scan(&lastActive) + if err != nil { + return err + } + + if !lastActive.Valid || messageTs.After(lastActive.Time) { + err = chatReadMessages(tx, ctx, userId, params.ChatID, messageTs) + if err != nil { + return err + } + } + case RPCMethodChatPermit: + var params ChatPermitRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + err = chatSetPermissions(tx, ctx, userId, params.Permit, params.PermitList, params.Allow, messageTs) + if err != nil { + return err + } + case RPCMethodChatBlock: + var params ChatBlockRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + blockeeUserId, err := trashid.DecodeHashId(params.UserID) + if err != nil { + return err + } + err = chatBlock(tx, ctx, userId, int32(blockeeUserId), messageTs) + if err != nil { + return err + } + case RPCMethodChatUnblock: + var params ChatUnblockRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + unblockedUserId, err := trashid.DecodeHashId(params.UserID) + if err != nil { + return err + } + err = chatUnblock(tx, ctx, userId, int32(unblockedUserId), messageTs) + if err != nil { + return err + } + + case RPCMethodChatBlast: + var params ChatBlastRPCParams + err = json.Unmarshal(rawRpc.Params, ¶ms) + if err != nil { + return err + } + + outgoingMessages, err := chatBlast(tx, ctx, userId, messageTs, params) + if err != nil { + return err + } + // Send chat message websocket event to all recipients who have existing chats + for _, outgoingMessage := range outgoingMessages { + _, err := json.Marshal(outgoingMessage.ChatMessageRPC) + if err != nil { + logger.Error("err: invalid json", zap.Error(err)) + } else { + // TODO + // websocketNotify(json.RawMessage(j), userId, messageTs.Round(time.Microsecond)) + } + } + default: + logger.Warn("no handler for ", zap.String("method", rawRpc.Method)) + } + + logger.Debug("called handler", zap.Duration("took", takeSplit())) + + err = tx.Commit(ctx) + if err != nil { + return err + } + logger.Debug("commited", zap.Duration("took", takeSplit())) + + // TODO + // send out websocket events fire + forget style + // websocketNotify(rpcLog.Rpc, userId, messageTs.Round(time.Microsecond)) + logger.Debug("websocket push done", zap.Duration("took", takeSplit())) + + return nil + } + + err = attemptApply() + if err != nil { + logger.Warn("apply failed", zap.Error(err)) + } + return err +} + +func (proc *RPCProcessor) GetRPCCurrentUserID(ctx context.Context, rpcLog *RpcLog, rawRpc *RawRPC) (int32, error) { + walletAddress := rpcLog.FromWallet + encodedCurrentUserId := rawRpc.CurrentUserID + + // attempt to read the (newly added) current_user_id field + if encodedCurrentUserId != "" { + if u, err := trashid.DecodeHashId(encodedCurrentUserId); err == nil && u > 0 { + + const checkCurrentUserQuery = ` + select count(*) > 0 + from users + where is_current = true + and user_id = $1 + and wallet = lower($2) + and handle IS NOT NULL + and is_available = TRUE + and is_deactivated = FALSE + ` + // valid current_user_id + wallet combo? + // for now just check that the pair exists in the user table + // in the future this can check a "grants" table that a given operation is permitted + isValid := false + err := proc.pool.QueryRow(ctx, checkCurrentUserQuery, u, walletAddress).Scan(&isValid) + if err == nil && isValid { + return int32(u), nil + } else { + proc.logger.Warn("invalid current_user_id", zap.Error(err), zap.String("wallet", walletAddress), zap.String("encoded_user_id", encodedCurrentUserId), zap.Int("user_id", u)) + } + } + } + + const getUserIDFromWalletQuery = ` + select user_id + from users + where is_current = TRUE + and handle IS NOT NULL + and is_available = TRUE + and is_deactivated = FALSE + and wallet = LOWER($1) + order by user_id asc + ` + // fallback to looking up user_id using wallet alone + var userId int32 + err := proc.pool.QueryRow(ctx, getUserIDFromWalletQuery, walletAddress).Scan(&userId) + return userId, err +} + +func insertRpcLogRow(tx pgx.Tx, ctx context.Context, rpcLog *RpcLog) (int64, error) { + query := ` + INSERT INTO rpc_log (relayed_by, relayed_at, applied_at, from_wallet, rpc, sig) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT DO NOTHING + ` + result, err := tx.Exec(ctx, query, rpcLog.RelayedBy, rpcLog.RelayedAt, time.Now(), rpcLog.FromWallet, rpcLog.Rpc, rpcLog.Sig) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} + +// func websocketNotify(rpcJson json.RawMessage, userId int32, timestamp time.Time) { +// if chatId := gjson.GetBytes(rpcJson, "params.chat_id").String(); chatId != "" { + +// var userIds []int32 +// err := db.Conn.Select(&userIds, `select user_id from chat_member where chat_id = $1 and is_hidden = false`, chatId) +// if err != nil { +// logger.Warn("failed to load chat members for websocket push " + err.Error()) +// return +// } + +// for _, receiverUserId := range userIds { +// websocketPush(userId, receiverUserId, rpcJson, timestamp) +// } +// } else if gjson.GetBytes(rpcJson, "method").String() == "chat.blast" { +// go func() { +// // Add delay before broadcasting blast messages - see PAY-3573 +// time.Sleep(30 * time.Second) +// websocketPushAll(userId, rpcJson, timestamp) +// }() +// } +// } diff --git a/api/comms/schema.go b/api/comms/schema.go new file mode 100644 index 00000000..3aac7523 --- /dev/null +++ b/api/comms/schema.go @@ -0,0 +1,415 @@ +package comms + +type ValidateCanChatRPC struct { + Method ValidateCanChatRPCMethod `json:"method"` + Params ValidateCanChatRPCParams `json:"params"` +} + +type ValidateCanChatRPCParams struct { + ReceiverUserIDS []string `json:"receiver_user_ids"` +} + +type ChatBlastRPC struct { + Method ChatBlastRPCMethod `json:"method"` + Params ChatBlastRPCParams `json:"params"` +} + +type ChatBlastRPCParams struct { + Audience ChatBlastAudience `json:"audience"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + BlastID string `json:"blast_id"` + Message string `json:"message"` +} + +type ChatCreateRPC struct { + Method ChatCreateRPCMethod `json:"method"` + Params ChatCreateRPCParams `json:"params"` +} + +type ChatCreateRPCParams struct { + ChatID string `json:"chat_id"` + Invites []PurpleInvite `json:"invites"` +} + +type PurpleInvite struct { + InviteCode string `json:"invite_code"` + UserID string `json:"user_id"` +} + +type ChatDeleteRPC struct { + Method ChatDeleteRPCMethod `json:"method"` + Params ChatDeleteRPCParams `json:"params"` +} + +type ChatDeleteRPCParams struct { + ChatID string `json:"chat_id"` +} + +type ChatInviteRPC struct { + Method ChatInviteRPCMethod `json:"method"` + Params ChatInviteRPCParams `json:"params"` +} + +type ChatInviteRPCParams struct { + ChatID string `json:"chat_id"` + Invites []FluffyInvite `json:"invites"` +} + +type FluffyInvite struct { + InviteCode string `json:"invite_code"` + UserID string `json:"user_id"` +} + +type ChatMessageRPC struct { + Method ChatMessageRPCMethod `json:"method"` + Params ChatMessageRPCParams `json:"params"` +} + +type ChatMessageRPCParams struct { + ChatID string `json:"chat_id"` + IsPlaintext *bool `json:"is_plaintext,omitempty"` + Message string `json:"message"` + MessageID string `json:"message_id"` + ParentMessageID *string `json:"parent_message_id,omitempty"` + Audience *ChatBlastAudience `json:"audience,omitempty"` +} + +type ChatReactRPC struct { + Method ChatReactRPCMethod `json:"method"` + Params ChatReactRPCParams `json:"params"` +} + +type ChatReactRPCParams struct { + ChatID string `json:"chat_id"` + MessageID string `json:"message_id"` + Reaction *string `json:"reaction"` +} + +type ChatReadRPC struct { + Method ChatReadRPCMethod `json:"method"` + Params ChatReadRPCParams `json:"params"` +} + +type ChatReadRPCParams struct { + ChatID string `json:"chat_id"` +} + +type ChatBlockRPC struct { + Method ChatBlockRPCMethod `json:"method"` + Params ChatBlockRPCParams `json:"params"` +} + +type ChatBlockRPCParams struct { + UserID string `json:"user_id"` +} + +type ChatUnblockRPC struct { + Method ChatUnblockRPCMethod `json:"method"` + Params ChatUnblockRPCParams `json:"params"` +} + +type ChatUnblockRPCParams struct { + UserID string `json:"user_id"` +} + +type ChatPermitRPC struct { + Method ChatPermitRPCMethod `json:"method"` + Params ChatPermitRPCParams `json:"params"` +} + +type ChatPermitRPCParams struct { + Allow *bool `json:"allow,omitempty"` + Permit ChatPermission `json:"permit"` + PermitList []ChatPermission `json:"permit_list"` +} + +type RPCPayloadRequest struct { + Method RPCMethod `json:"method"` + Params RPCPayloadRequestParams `json:"params"` +} + +type RPCPayloadRequestParams struct { + ReceiverUserIDS []string `json:"receiver_user_ids,omitempty"` + Audience *ChatBlastAudience `json:"audience,omitempty"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + BlastID *string `json:"blast_id,omitempty"` + Message *string `json:"message,omitempty"` + ChatID *string `json:"chat_id,omitempty"` + Invites []TentacledInvite `json:"invites,omitempty"` + IsPlaintext *bool `json:"is_plaintext,omitempty"` + MessageID *string `json:"message_id,omitempty"` + ParentMessageID *string `json:"parent_message_id,omitempty"` + Reaction *string `json:"reaction"` + UserID *string `json:"user_id,omitempty"` + Allow *bool `json:"allow,omitempty"` + Permit *ChatPermission `json:"permit,omitempty"` + PermitList []ChatPermission `json:"permit_list,omitempty"` +} + +type TentacledInvite struct { + InviteCode string `json:"invite_code"` + UserID string `json:"user_id"` +} + +type UserChat struct { + Audience ChatBlastAudience `json:"audience"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *string `json:"audience_content_type,omitempty"` + ChatID string `json:"chat_id"` + ChatMembers []ChatMember `json:"chat_members"` + ClearedHistoryAt string `json:"cleared_history_at"` + InviteCode string `json:"invite_code"` + IsBlast bool `json:"is_blast"` + LastMessage string `json:"last_message"` + LastMessageAt string `json:"last_message_at"` + LastMessageIsPlaintext bool `json:"last_message_is_plaintext"` + LastReadAt string `json:"last_read_at"` + RecheckPermissions bool `json:"recheck_permissions"` + UnreadMessageCount float64 `json:"unread_message_count"` +} + +type ChatMember struct { + UserID string `json:"user_id"` +} + +type ChatMessageReaction struct { + CreatedAt string `json:"created_at"` + Reaction string `json:"reaction"` + UserID string `json:"user_id"` +} + +type ChatMessageNullableReaction struct { + CreatedAt string `json:"created_at"` + Reaction *string `json:"reaction"` + UserID string `json:"user_id"` +} + +type ChatMessage struct { + CreatedAt string `json:"created_at"` + IsPlaintext bool `json:"is_plaintext"` + Message string `json:"message"` + MessageID string `json:"message_id"` + Reactions []Reaction `json:"reactions"` + SenderUserID string `json:"sender_user_id"` +} + +type Reaction struct { + CreatedAt string `json:"created_at"` + Reaction string `json:"reaction"` + UserID string `json:"user_id"` +} + +type ChatInvite struct { + InviteCode string `json:"invite_code"` + UserID string `json:"user_id"` +} + +type ChatBlastBase struct { + Audience ChatBlastAudience `json:"audience"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + ChatID string `json:"chat_id"` +} + +type UpgradableChatBlast struct { + Audience ChatBlastAudience `json:"audience"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + BlastID string `json:"blast_id"` + ChatID string `json:"chat_id"` + CreatedAt string `json:"created_at"` + FromUserID float64 `json:"from_user_id"` + PendingChatID string `json:"pending_chat_id"` + Plaintext string `json:"plaintext"` +} + +type ChatBlast struct { + Audience ChatBlastAudience `json:"audience"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + ChatID string `json:"chat_id"` + IsBlast bool `json:"is_blast"` + LastMessageAt string `json:"last_message_at"` +} + +type ValidatedChatPermissions struct { + CurrentUserHasPermission bool `json:"current_user_has_permission"` + PermitList []ChatPermission `json:"permit_list"` + Permits ChatPermission `json:"permits"` + UserID string `json:"user_id"` +} + +type CommsResponse struct { + Data interface{} `json:"data"` + Health Health `json:"health"` + Summary *Summary `json:"summary,omitempty"` +} + +type Health struct { + IsHealthy bool `json:"is_healthy"` +} + +type Summary struct { + NextCount float64 `json:"next_count"` + NextCursor string `json:"next_cursor"` + PrevCount float64 `json:"prev_count"` + PrevCursor string `json:"prev_cursor"` + TotalCount float64 `json:"total_count"` +} + +type ChatWebsocketEventData struct { + Metadata Metadata `json:"metadata"` + RPC RPCPayload `json:"rpc"` +} + +type Metadata struct { + ReceiverUserID string `json:"receiverUserId"` + SenderUserID string `json:"senderUserId"` + Timestamp string `json:"timestamp"` + UserID string `json:"userId"` +} + +type RPCPayload struct { + CurrentUserID string `json:"current_user_id"` + Method RPCMethod `json:"method"` + Params RPCPayloadParams `json:"params"` + Timestamp float64 `json:"timestamp"` +} + +type RPCPayloadParams struct { + ReceiverUserIDS []string `json:"receiver_user_ids,omitempty"` + Audience *ChatBlastAudience `json:"audience,omitempty"` + AudienceContentID *string `json:"audience_content_id,omitempty"` + AudienceContentType *AudienceContentType `json:"audience_content_type,omitempty"` + BlastID *string `json:"blast_id,omitempty"` + Message *string `json:"message,omitempty"` + ChatID *string `json:"chat_id,omitempty"` + Invites []StickyInvite `json:"invites,omitempty"` + IsPlaintext *bool `json:"is_plaintext,omitempty"` + MessageID *string `json:"message_id,omitempty"` + ParentMessageID *string `json:"parent_message_id,omitempty"` + Reaction *string `json:"reaction"` + UserID *string `json:"user_id,omitempty"` + Allow *bool `json:"allow,omitempty"` + Permit *ChatPermission `json:"permit,omitempty"` + PermitList []ChatPermission `json:"permit_list,omitempty"` +} + +type StickyInvite struct { + InviteCode string `json:"invite_code"` + UserID string `json:"user_id"` +} + +type ValidateCanChatRPCMethod string + +const ( + MethodUserValidateCanChat ValidateCanChatRPCMethod = "user.validate_can_chat" +) + +type ChatBlastRPCMethod string + +const ( + MethodChatBlast ChatBlastRPCMethod = "chat.blast" +) + +type ChatBlastAudience string + +const ( + CustomerAudience ChatBlastAudience = "customer_audience" + FollowerAudience ChatBlastAudience = "follower_audience" + RemixerAudience ChatBlastAudience = "remixer_audience" + TipperAudience ChatBlastAudience = "tipper_audience" + CoinHolderAudience ChatBlastAudience = "coin_holder_audience" +) + +type AudienceContentType string + +const ( + Album AudienceContentType = "album" + Track AudienceContentType = "track" +) + +type ChatCreateRPCMethod string + +const ( + MethodChatCreate ChatCreateRPCMethod = "chat.create" +) + +type ChatDeleteRPCMethod string + +const ( + MethodChatDelete ChatDeleteRPCMethod = "chat.delete" +) + +type ChatInviteRPCMethod string + +const ( + MethodChatInvite ChatInviteRPCMethod = "chat.invite" +) + +type ChatMessageRPCMethod string + +const ( + MethodChatMessage ChatMessageRPCMethod = "chat.message" +) + +type ChatReactRPCMethod string + +const ( + MethodChatReact ChatReactRPCMethod = "chat.react" +) + +type ChatReadRPCMethod string + +const ( + MethodChatRead ChatReadRPCMethod = "chat.read" +) + +type ChatBlockRPCMethod string + +const ( + MethodChatBlock ChatBlockRPCMethod = "chat.block" +) + +type ChatUnblockRPCMethod string + +const ( + MethodChatUnblock ChatUnblockRPCMethod = "chat.unblock" +) + +type ChatPermitRPCMethod string + +const ( + MethodChatPermit ChatPermitRPCMethod = "chat.permit" +) + +// Defines who the user allows to message them +type ChatPermission string + +const ( + ChatPermissionAll ChatPermission = "all" + ChatPermissionFollowees ChatPermission = "followees" + ChatPermissionFollowers ChatPermission = "followers" + ChatPermissionNone ChatPermission = "none" + ChatPermissionTippees ChatPermission = "tippees" + ChatPermissionTippers ChatPermission = "tippers" + ChatPermissionVerified ChatPermission = "verified" +) + +type RPCMethod string + +const ( + RPCMethodChatBlast RPCMethod = "chat.blast" + RPCMethodChatBlock RPCMethod = "chat.block" + RPCMethodChatCreate RPCMethod = "chat.create" + RPCMethodChatDelete RPCMethod = "chat.delete" + RPCMethodChatInvite RPCMethod = "chat.invite" + RPCMethodChatMessage RPCMethod = "chat.message" + RPCMethodChatPermit RPCMethod = "chat.permit" + RPCMethodChatReact RPCMethod = "chat.react" + RPCMethodChatRead RPCMethod = "chat.read" + RPCMethodChatUnblock RPCMethod = "chat.unblock" + RPCMethodUserValidateCanChat RPCMethod = "user.validate_can_chat" +) diff --git a/api/comms/signed_request.go b/api/comms/signed_request.go new file mode 100644 index 00000000..f6f3be8c --- /dev/null +++ b/api/comms/signed_request.go @@ -0,0 +1,40 @@ +package comms + +import ( + "encoding/base64" + "errors" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/gofiber/fiber/v2" +) + +func ReadSignedPost(c *fiber.Ctx) ([]byte, string, error) { + if c.Method() != "POST" { + return nil, "", errors.New("readSignedPost bad method: " + c.Method()) + } + + payload := c.Body() + + sigHex := c.Get(SigHeader) + wallet, err := recoverSigningWallet(sigHex, payload) + return payload, wallet, err +} + +func recoverSigningWallet(signatureHex string, signedData []byte) (string, error) { + sig, err := base64.StdEncoding.DecodeString(signatureHex) + if err != nil { + err = errors.New("bad sig header: " + err.Error()) + return "", err + } + + // recover + hash := crypto.Keccak256Hash(signedData) + pubkey, err := crypto.SigToPub(hash[:], sig) + if err != nil { + return "", err + } + + wallet := crypto.PubkeyToAddress(*pubkey).Hex() + + return wallet, nil +} diff --git a/api/comms/validator.go b/api/comms/validator.go new file mode 100644 index 00000000..873431ad --- /dev/null +++ b/api/comms/validator.go @@ -0,0 +1,656 @@ +package comms + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "bridgerton.audius.co/api/dbv1" + "bridgerton.audius.co/config" + "bridgerton.audius.co/trashid" + "github.com/jackc/pgx/v5" + "go.uber.org/zap" +) + +var ( + ErrMessageRateLimitExceeded = errors.New("user has exceeded the maximum number of new messages") +) + +type Validator struct { + logger *zap.Logger + pool *dbv1.DBPools + limiter *RateLimiter + aaoServer string +} + +func NewValidator(pool *dbv1.DBPools, limiter *RateLimiter, config *config.Config, logger *zap.Logger) *Validator { + // TODO: Don't hack around this for tests + if len(config.AntiAbuseOracles) == 0 && config.Env != "test" { + panic("no anti-abuse oracles configured, can't initialize comms validator") + } + aaoServer := "" + if len(config.AntiAbuseOracles) > 0 { + aaoServer = config.AntiAbuseOracles[0] + } + + return &Validator{ + pool: pool, + limiter: limiter, + aaoServer: aaoServer, + logger: logger, + } +} + +func (vtor *Validator) Validate(ctx context.Context, userId int32, rawRpc RawRPC) error { + methodName := RPCMethod(rawRpc.Method) + + // banned? + isBanned, err := vtor.isBanned(ctx, userId) + if err != nil { + return err + } + if isBanned { + return fmt.Errorf("user_id %d is banned from chat", userId) + } + + switch methodName { + case RPCMethodChatCreate: + return vtor.validateChatCreate(ctx, userId, rawRpc) + case RPCMethodChatDelete: + return vtor.validateChatDelete(userId, rawRpc) + case RPCMethodChatMessage: + return vtor.validateChatMessage(ctx, userId, rawRpc) + case RPCMethodChatReact: + return vtor.validateChatReact(vtor.pool, ctx, userId, rawRpc) + case RPCMethodChatRead: + return vtor.validateChatRead(userId, rawRpc) + case RPCMethodChatPermit: + return vtor.validateChatPermit(userId, rawRpc) + case RPCMethodChatBlock: + return vtor.validateChatBlock(userId, rawRpc) + case RPCMethodChatUnblock: + return vtor.validateChatUnblock(userId, rawRpc) + default: + vtor.logger.Debug("no validator for " + rawRpc.Method) + } + + return nil +} + +func (vtor *Validator) isBanned(ctx context.Context, userId int32) (bool, error) { + isBanned := false + err := vtor.pool.QueryRow(ctx, "select count(user_id) = 1 from chat_ban where user_id = $1 and is_banned = true", userId).Scan(&isBanned) + if err != nil { + return false, err + } + return isBanned, nil +} + +func (vtor *Validator) validateChatCreate(ctx context.Context, userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatCreateRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate chatId does not already exist + query := "select count(*) from chat where chat_id = $1;" + var chatCount int + err = vtor.pool.QueryRow(ctx, query, params.ChatID).Scan(&chatCount) + if err != nil { + return err + } + if chatCount != 0 { + return errors.New("Chat already exists") + } + + if len(params.Invites) != 2 { + return errors.New("Chat must have 2 members") + } + + user1, err := trashid.DecodeHashId(params.Invites[0].UserID) + if err != nil { + return err + } + user2, err := trashid.DecodeHashId(params.Invites[1].UserID) + if err != nil { + return err + } + + receiver := int32(user1) + if receiver == userId { + receiver = int32(user2) + } + + // Check that the creator is non-abusive + err = validateSenderPassesAbuseCheck(vtor.pool, ctx, vtor.logger, userId, vtor.aaoServer) + if err != nil { + return err + } + + // if recipient is creating a chat from a blast + // we ignore the receiver's inbox settings + // because receiver has sent a blast to this user. + { + hasBlast, err := hasNewBlastFromUser(vtor.pool, ctx, userId, receiver) + if err != nil { + return err + } + if hasBlast { + return nil + } + } + + // validate receiver permits chats from sender + err = validatePermissions(vtor.pool, ctx, userId, receiver) + if err != nil { + return err + } + + // validate does not exceed new chat rate limit for any invited users + var users []int32 + for _, invite := range params.Invites { + userId, err := trashid.DecodeHashId(invite.UserID) + if err != nil { + return err + } + users = append(users, int32(userId)) + } + err = vtor.validateNewChatRateLimit(vtor.pool, ctx, users) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatMessage(ctx context.Context, userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatMessageRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate userId is a member of chatId in good standing + err = validateChatMembership(vtor.pool, ctx, userId, params.ChatID) + if err != nil { + return err + } + + // validate not blocked and can chat according to receiver's inbox permission settings + err = validatePermittedToMessage(vtor.pool, ctx, userId, params.ChatID) + if err != nil { + return err + } + + // validate does not exceed new message rate limit + err = vtor.validateNewMessageRateLimit(vtor.pool, ctx, userId, params.ChatID) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatReact(pool *dbv1.DBPools, ctx context.Context, userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatReactRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate userId is a member of chatId in good standing + err = validateChatMembership(vtor.pool, ctx, userId, params.ChatID) + if err != nil { + return err + } + + // validate message exists in chat + var exists bool + err = pool.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM chat_message WHERE chat_id = $1 AND message_id = $2)`, params.ChatID, params.MessageID).Scan(&exists) + if err != nil { + return err + } + if !exists { + return errors.New("message does not exist in chat") + } + + // validate not blocked and can chat according to receiver's inbox permission settings + err = validatePermittedToMessage(vtor.pool, ctx, userId, params.ChatID) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatRead(userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatReadRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate userId is a member of chatId in good standing + err = validateChatMembership(vtor.pool, context.Background(), userId, params.ChatID) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatPermit(userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatPermitRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatBlock(userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatBlockRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + return nil +} + +func (vtor *Validator) validateChatUnblock(userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatBlockRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate that params.UserID is currently blocked by userId + blockeeUserId, err := trashid.DecodeHashId(params.UserID) + if err != nil { + return err + } + + var exists bool + err = vtor.pool.QueryRow(context.Background(), ` + select exists( + select 1 from chat_blocked_users + where blocker_user_id = $1 and blockee_user_id = $2 + ) + `, userId, blockeeUserId).Scan(&exists) + if err != nil { + return err + } + if !exists { + return errors.New("user is not blocked") + } + + return nil +} + +func (vtor *Validator) validateChatDelete(userId int32, rpc RawRPC) error { + // validate rpc.params valid + var params ChatDeleteRPCParams + err := json.Unmarshal(rpc.Params, ¶ms) + if err != nil { + return err + } + + // validate userId is a member of chatId in good standing + err = validateChatMembership(vtor.pool, context.Background(), userId, params.ChatID) + if err != nil { + return err + } + + return nil +} + +// Calculate cursor from rate limit timeframe +func (vtor *Validator) calculateRateLimitCursor(timeframe int) time.Time { + return time.Now().UTC().Add(-time.Hour * time.Duration(timeframe)) +} + +func (vtor *Validator) validateNewChatRateLimit(pool *dbv1.DBPools, ctx context.Context, users []int32) error { + var err error + + // rate_limit_seconds + + limiter := vtor.limiter + timeframe := limiter.Get(RateLimitTimeframeHours) + + // Max num of new chats permitted per timeframe + maxNumChats := limiter.Get(RateLimitMaxNumNewChats) + + cursor := vtor.calculateRateLimitCursor(timeframe) + + // Build the query with proper placeholders for the IN clause + query := ` + WITH counts AS ( + SELECT COUNT(*) AS count + FROM chat + JOIN chat_member on chat.chat_id = chat_member.chat_id + WHERE chat_member.user_id = ANY($1) AND chat.created_at > $2 + GROUP BY chat_member.user_id + ) + SELECT COALESCE(MAX(count), 0) FROM counts; + ` + + var numChats int + err = pool.QueryRow(ctx, query, users, cursor).Scan(&numChats) + if err != nil { + return err + } + if numChats >= maxNumChats { + vtor.logger.Info("hit rate limit (new chats)", zap.Any("users", users)) + return errors.New("An invited user has exceeded the maximum number of new chats") + } + + return nil +} + +func (vtor *Validator) validateNewMessageRateLimit(pool *dbv1.DBPools, ctx context.Context, userId int32, chatId string) error { + var err error + + // BurstRateLimit + { + query := ` + select + coalesce(sum(case when created_at > now() - interval '1 second' then 1 else 0 end), 0) as s1, + coalesce(sum(case when created_at > now() - interval '10 seconds' then 1 else 0 end), 0) as s10, + coalesce(sum(case when created_at > now() - interval '60 seconds' then 1 else 0 end), 0) as s60 + from chat_message + where user_id = $1 + and created_at > now() - interval '60 seconds'; + ` + var s1, s10, s60 int64 + err = pool.QueryRow(ctx, query, userId).Scan(&s1, &s10, &s60) + if err != nil { + vtor.logger.Error("burst rate limit query failed", zap.Error(err)) + } + + // 10 per second in last second + if s1 > 10 { + vtor.logger.Warn("message rate limit exceeded", zap.String("bucket", "1s"), zap.Int32("user_id", userId), zap.Int64("count", s1)) + return ErrMessageRateLimitExceeded + + } + + // 7 per second for last 10 seconds + if s10 > 70 { + vtor.logger.Warn("message rate limit exceeded", zap.String("bucket", "10s"), zap.Int32("user_id", userId), zap.Int64("count", s10)) + return ErrMessageRateLimitExceeded + } + + // 5 per second for last 60 seconds + if s60 > 300 { + vtor.logger.Warn("message rate limit exceeded", zap.String("bucket", "60s"), zap.Int32("user_id", userId), zap.Int64("count", s60)) + return ErrMessageRateLimitExceeded + } + } + + limiter := vtor.limiter + timeframe := limiter.Get(RateLimitTimeframeHours) + + // Max number of new messages permitted per timeframe + maxNumMessages := limiter.Get(RateLimitMaxNumMessages) + + // Max number of new messages permitted per recipient (chat) per timeframe + maxNumMessagesPerRecipient := limiter.Get(RateLimitMaxNumMessagesPerRecipient) + + // Cursor for rate limit timeframe + cursor := vtor.calculateRateLimitCursor(timeframe) + + // Check total message count and max messages per chat + query := ` + WITH counts_per_chat AS ( + SELECT COUNT(*) + FROM chat_message + WHERE user_id = $1 and created_at > $2 + GROUP BY chat_id + ) + SELECT COALESCE(SUM(count), 0) AS total_count, COALESCE(MAX(count), 0) as max_count_per_chat FROM counts_per_chat; + ` + + var totalCount, maxCountPerChat int + err = pool.QueryRow(ctx, query, userId, cursor).Scan(&totalCount, &maxCountPerChat) + if err != nil { + return err + } + if totalCount >= maxNumMessages || maxCountPerChat >= maxNumMessagesPerRecipient { + if totalCount >= maxNumMessages { + vtor.logger.Info("hit rate limit (total count new messages)", zap.Int32("user", userId), zap.String("chat", chatId)) + } + if maxCountPerChat >= maxNumMessagesPerRecipient { + vtor.logger.Info("hit rate limit (new messages per recipient)", zap.Int32("user", userId), zap.String("chat", chatId)) + } + return ErrMessageRateLimitExceeded + } + + return nil +} + +func validateChatMembership(pool *dbv1.DBPools, ctx context.Context, userId int32, chatId string) error { + var exists bool + err := pool.QueryRow(ctx, `select exists(select 1 from chat_member where user_id = $1 and chat_id = $2)`, userId, chatId).Scan(&exists) + if err != nil { + return err + } + if !exists { + return errors.New("user is not a member of this chat") + } + return nil +} + +func validatePermissions(pool *dbv1.DBPools, ctx context.Context, sender int32, receiver int32) error { + permissionFailure := errors.New("Not permitted to send messages to this user") + + ok := false + err := pool.QueryRow(ctx, `select chat_allowed($1, $2)`, sender, receiver).Scan(&ok) + if err != nil { + return err + } + if !ok { + return permissionFailure + } + return nil + +} + +func validatePermittedToMessage(pool *dbv1.DBPools, ctx context.Context, userId int32, chatId string) error { + // Single query that validates: + // 1. Chat has exactly 2 members + // 2. User is a member of the chat + // 3. User has permission to message the other member + query := ` + WITH chat_members AS ( + SELECT user_id + FROM chat_member + WHERE chat_id = $1 + ), + member_count AS ( + SELECT COUNT(*) as count + FROM chat_members + ), + other_member AS ( + SELECT user_id + FROM chat_members + WHERE user_id != $2 + ) + SELECT + CASE + WHEN mc.count != 2 THEN false + WHEN NOT EXISTS (SELECT 1 FROM chat_members WHERE user_id = $2) THEN false + WHEN NOT chat_allowed($2, om.user_id) THEN false + ELSE true + END as is_permitted + FROM member_count mc + CROSS JOIN other_member om + ` + + var isPermitted bool + err := pool.QueryRow(ctx, query, chatId, userId).Scan(&isPermitted) + if err != nil { + if err == pgx.ErrNoRows { + return errors.New("Chat must have 2 members") + } + return err + } + + if !isPermitted { + return errors.New("Not permitted to send messages to this user") + } + + return nil +} + +var ErrAttestationFailed = errors.New("attestation failed") + +// TODO: Better AAO usage that corresponds to the claim rewards code +func validateSenderPassesAbuseCheck(pool *dbv1.DBPools, ctx context.Context, logger *zap.Logger, userId int32, aaoServer string) error { + if aaoServer == "" { + return nil + } + // Keeping this somewhat opaque as it gets sent to client + var handle string + err := pool.QueryRow(ctx, `SELECT handle FROM users WHERE user_id = $1`, userId).Scan(&handle) + if err != nil { + if err == pgx.ErrNoRows { + return fmt.Errorf("user %d not found", userId) + } + return err + } + + url := fmt.Sprintf("%s/attestation/%s", aaoServer, handle) + // Dummy challenge for now to mitigate + requestBody := []byte(`{ "challengeId": "x", "challengeSpecifier": "x", "amount": 0 }`) + resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody)) + if err != nil { + logger.Error("Error checking user attestation", zap.Error(err), zap.String("handle", handle)) + return err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + logger.Warn("User failed AAO check", zap.Int32("userId", userId), zap.Int("status", resp.StatusCode), zap.String("aaoServer", aaoServer)) + return ErrAttestationFailed + } + return nil +} + +// hasNewBlastFromUser efficiently checks if a new blast exists from a specific user +// without fetching all blast data. Returns true if a valid blast exists, false otherwise. +func hasNewBlastFromUser(pool *dbv1.DBPools, ctx context.Context, userID int32, fromUserID int32) (bool, error) { + // Construct the expected chat ID for this user pair + expectedChatID := trashid.ChatID(int(userID), int(fromUserID)) + + // This query checks for the existence of a new blast from a specific user + // using the same logic as GetNewBlasts but optimized for existence check + var hasNewBlast = ` + with + last_permission_change as ( + select max(t) as t from ( + select updated_at as t from chat_permissions where user_id = $1 + union + select created_at as t from chat_blocked_users where blocker_user_id = $1 + union + select to_timestamp(0) + ) as timestamp_subquery + ) + select exists( + select 1 + from chat_blast blast + where + blast.from_user_id = $2 + and blast.created_at > (select t from last_permission_change) + and chat_allowed(blast.from_user_id, $1) + and not exists ( + select 1 from chat_member cm + where cm.user_id = $1 and cm.chat_id = $3 + ) + and ( + -- follower_audience + (blast.audience = 'follower_audience' and exists ( + SELECT 1 + FROM follows + WHERE follows.followee_user_id = blast.from_user_id + AND follows.follower_user_id = $1 + AND follows.is_delete = false + AND follows.created_at < blast.created_at + )) + OR + -- tipper_audience + (blast.audience = 'tipper_audience' and exists ( + SELECT 1 + FROM user_tips tip + WHERE receiver_user_id = blast.from_user_id + AND sender_user_id = $1 + AND tip.created_at < blast.created_at + )) + OR + -- remixer_audience + (blast.audience = 'remixer_audience' and exists ( + SELECT 1 + FROM tracks t + JOIN remixes ON remixes.child_track_id = t.track_id + JOIN tracks og ON remixes.parent_track_id = og.track_id + WHERE og.owner_id = blast.from_user_id + AND t.owner_id = $1 + AND ( + blast.audience_content_id IS NULL + OR ( + blast.audience_content_type = 'track' + AND blast.audience_content_id = og.track_id + ) + ) + )) + OR + -- customer_audience + (blast.audience = 'customer_audience' and exists ( + SELECT 1 + FROM usdc_purchases p + WHERE p.seller_user_id = blast.from_user_id + AND p.buyer_user_id = $1 + AND ( + blast.audience_content_id IS NULL + OR ( + blast.audience_content_type = p.content_type::text + AND blast.audience_content_id = p.content_id + ) + ) + )) + OR + -- coin_holder_audience + (blast.audience = 'coin_holder_audience' and exists ( + SELECT 1 + FROM artist_coins ac + JOIN sol_user_balances sub ON sub.mint = ac.mint + WHERE ac.user_id = blast.from_user_id + AND sub.user_id = $1 + AND sub.balance > 0 + -- TODO: PE-6663 This isn't entirely correct yet, need to check "time of most recent membership" + AND sub.created_at < blast.created_at + )) + ) + )` + + var exists bool + err := pool.QueryRow(ctx, hasNewBlast, userID, fromUserID, expectedChatID).Scan(&exists) + if err != nil { + return false, err + } + + return exists, nil +} diff --git a/api/comms_mutate.go b/api/comms_mutate.go new file mode 100644 index 00000000..e5afa369 --- /dev/null +++ b/api/comms_mutate.go @@ -0,0 +1,61 @@ +package api + +import ( + "encoding/json" + "errors" + "time" + + comms "bridgerton.audius.co/api/comms" + + "github.com/gofiber/fiber/v2" + "go.uber.org/zap" +) + +func (app *ApiServer) mutateChat(c *fiber.Ctx) error { + payload, wallet, err := comms.ReadSignedPost(c) + if err != nil { + return fiber.NewError(fiber.StatusBadRequest, "bad request: "+err.Error()) + } + + // unmarshal RPC and call validator + var rawRpc comms.RawRPC + err = json.Unmarshal(payload, &rawRpc) + if err != nil { + return fiber.NewError(fiber.StatusBadRequest, "bad request: "+err.Error()) + } + + rpcLog := &comms.RpcLog{ + RelayedBy: "bridge", + RelayedAt: time.Now(), + FromWallet: wallet, + Rpc: payload, + Sig: c.Get(comms.SigHeader), + } + + userId, err := app.getUserIDFromWallet(c.Context(), wallet) + if err != nil { + return err + } + + err = app.commsRpcProcessor.Validate(c.Context(), int32(userId), rawRpc) + if err != nil { + if errors.Is(err, comms.ErrAttestationFailed) { + return fiber.NewError( + fiber.StatusForbidden, + "forbidden to make this request: "+err.Error(), + ) + } + return fiber.NewError( + fiber.StatusBadRequest, + "bad request: "+err.Error(), + ) + } + + err = app.commsRpcProcessor.Apply(c.Context(), rpcLog) + if err != nil { + app.logger.Warn("comms rpc apply failed", zap.String("payload", string(payload)), zap.String("wallet", wallet), zap.Error(err)) + return err + } + app.logger.Debug("comms rpc apply succeeded", zap.String("payload", string(payload)), zap.String("wallet", wallet), zap.Bool("relay", true)) + return c.JSON(rpcLog) +} diff --git a/api/server.go b/api/server.go index 0db5b7a7..b1356274 100644 --- a/api/server.go +++ b/api/server.go @@ -12,6 +12,7 @@ import ( "time" "bridgerton.audius.co/api/birdeye" + comms "bridgerton.audius.co/api/comms" "bridgerton.audius.co/api/dbv1" "bridgerton.audius.co/config" "bridgerton.audius.co/esindexer" @@ -163,6 +164,11 @@ func NewApiServer(config config.Config) *ApiServer { metricsCollector = NewMetricsCollector(logger, writePool) } + commsRpcProcessor, err := comms.NewProcessor(pool, writePool, &config, logger) + if err != nil { + panic(err) + } + app := &ApiServer{ App: fiber.New(fiber.Config{ JSONEncoder: json.Marshal, @@ -171,6 +177,7 @@ func NewApiServer(config config.Config) *ApiServer { ReadBufferSize: 32_768, UnescapePath: true, }), + commsRpcProcessor: commsRpcProcessor, env: config.Env, skipAuthCheck: skipAuthCheck, pool: pool, @@ -289,6 +296,7 @@ func NewApiServer(config config.Config) *ApiServer { app.Use("/v1/full/tracks/best_new_releases", BalancerForward(config.PythonUpstreams)) app.Use("/v1/full/tracks/most_loved", BalancerForward(config.PythonUpstreams)) app.Use("/v1/full/tracks/remixables", BalancerForward(config.PythonUpstreams)) + app.Use("/comms/chats/ws", BalancerForward(config.PythonUpstreams)) } v1 := app.Group("/v1") @@ -500,6 +508,8 @@ func NewApiServer(config config.Config) *ApiServer { comms.Get("/blasts", app.getNewBlasts) + comms.Post("/mutate", app.mutateChat) + // Block confirmation app.Get("/block_confirmation", app.BlockConfirmation) app.Get("/block-confirmation", app.BlockConfirmation) @@ -553,6 +563,7 @@ type BirdeyeClient interface { type ApiServer struct { *fiber.App + commsRpcProcessor *comms.RPCProcessor pool *dbv1.DBPools writePool *pgxpool.Pool queries *dbv1.Queries diff --git a/trashid/chatid.go b/trashid/chatid.go index 2109da2e..6716f4aa 100644 --- a/trashid/chatid.go +++ b/trashid/chatid.go @@ -2,6 +2,7 @@ package trashid import "fmt" +// TODO: Handle errors here if we can't encode the IDs // ChatID return a encodedUser1:encodedUser2 ID where encodedUser1 is < encodedUser2 // which is the convention used to make chat IDs deterministic. // See makeChatId in SDK: packages/common/src/store/pages/chat/utils.ts @@ -14,3 +15,8 @@ func ChatID(id1, id2 int) string { } return chatId } + +// Returns a unique Message ID for a blast message in a chat. +func BlastMessageID(blastID, chatID string) string { + return blastID + chatID +}