Skip to content

Commit

Permalink
messagix: make Client nil-safe
Browse files Browse the repository at this point in the history
Closes #117
  • Loading branch information
tulir committed Feb 12, 2025
1 parent ecf4a00 commit 579abac
Show file tree
Hide file tree
Showing 17 changed files with 123 additions and 67 deletions.
9 changes: 6 additions & 3 deletions pkg/connector/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (m *MetaClient) requestMoreHistory(ctx context.Context, threadID, minTimest
ReferenceTimestampMs: minTimestampMS,
ReferenceMessageId: minMessageID,
SyncGroup: 1,
Cursor: m.Client.SyncManager.GetCursor(1),
Cursor: m.Client.GetCursor(1),
})
zerolog.Ctx(ctx).Trace().
Int64("thread_id", threadID).
Expand Down Expand Up @@ -172,7 +172,7 @@ func (m *MetaClient) removeBackfillCollector(threadID int64, collector *Backfill
}

func (m *MetaClient) FetchMessages(ctx context.Context, params bridgev2.FetchMessagesParams) (*bridgev2.FetchMessagesResponse, error) {
if m.Client == nil || m.Client.SyncManager == nil {
if m.Client == nil {
return nil, bridgev2.ErrNotLoggedIn
}
if params.Portal.Metadata.(*metaid.PortalMetadata).ThreadType == table.ENCRYPTED_OVER_WA_GROUP {
Expand Down Expand Up @@ -231,7 +231,10 @@ func (m *MetaClient) FetchMessages(ctx context.Context, params bridgev2.FetchMes
if !m.addBackfillCollector(threadID, collector) {
return nil, fmt.Errorf("backfill collector already exists for thread %d", threadID)
}
m.requestMoreHistory(ctx, threadID, oldestMessageTS, oldestMessageID)
if !m.requestMoreHistory(ctx, threadID, oldestMessageTS, oldestMessageID) {
m.removeBackfillCollector(threadID, collector)
return nil, fmt.Errorf("failed to request more history for thread %d", threadID)
}
select {
case <-doneCh:
upsert = collector.UpsertMessages
Expand Down
9 changes: 7 additions & 2 deletions pkg/connector/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,10 @@ func (m *MetaClient) connectE2EE() error {
return fmt.Errorf("failed to save device ID to user login: %w", err)
}
}
m.E2EEClient = m.Client.PrepareE2EEClient()
m.E2EEClient, err = m.Client.PrepareE2EEClient()
if err != nil {
return fmt.Errorf("failed to prepare e2ee client: %w", err)
}
m.E2EEClient.AddEventHandler(m.e2eeEventHandler)
err = m.E2EEClient.Connect()
if err != nil {
Expand All @@ -373,10 +376,12 @@ func (m *MetaClient) Disconnect() {
(*stopConnectAttempt)()
}
if cli := m.Client; cli != nil {
cli.SetEventHandler(nil)
cli.Disconnect()
m.Client = nil
}
if ecli := m.E2EEClient; ecli != nil {
ecli.RemoveEventHandlers()
ecli.Disconnect()
m.E2EEClient = nil
}
Expand All @@ -391,7 +396,7 @@ func (m *MetaClient) Disconnect() {
}

func (m *MetaClient) IsLoggedIn() bool {
return m.Client != nil && m.Client.SyncManager != nil
return m.Client.IsAuthenticated()
}

func (m *MetaClient) IsThisUser(ctx context.Context, userID networkid.UserID) bool {
Expand Down
67 changes: 45 additions & 22 deletions pkg/messagix/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net"
"net/http"
"net/url"
"os"
"slices"
"strconv"
"sync"
Expand All @@ -18,6 +17,7 @@ import (

"github.com/google/go-querystring/query"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow"
"go.mau.fi/whatsmeow/store"
"golang.org/x/net/proxy"

Expand All @@ -43,6 +43,8 @@ const SecCHMobile = "?0"
const SecCHModel = `""`
const SecCHPrefersColorScheme = "light"

var ErrClientIsNil = whatsmeow.ErrClientIsNil

type EventHandler func(evt interface{})
type Client struct {
Instagram *InstagramMethods
Expand All @@ -54,7 +56,7 @@ type Client struct {
socket *Socket
eventHandler EventHandler
configs *Configs
SyncManager *SyncManager
syncManager *SyncManager

cookies *cookies.Cookies
httpProxy func(*http.Request) (*url.URL, error)
Expand Down Expand Up @@ -124,7 +126,9 @@ func NewClient(cookies *cookies.Cookies, logger zerolog.Logger) *Client {
}

func (c *Client) LoadMessagesPage() (types.UserInfo, *table.LSTable, error) {
if !c.cookies.IsLoggedIn() {
if c == nil {
return nil, nil, ErrClientIsNil
} else if !c.cookies.IsLoggedIn() {
return nil, nil, fmt.Errorf("can't load messages page without being authenticated")
}

Expand All @@ -134,7 +138,7 @@ func (c *Client) LoadMessagesPage() (types.UserInfo, *table.LSTable, error) {
return nil, nil, fmt.Errorf("failed to load inbox: %w", err)
}

c.SyncManager = c.NewSyncManager()
c.syncManager = c.newSyncManager()
ls, err := c.configs.SetupConfigs(moduleLoader.LS)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -175,6 +179,9 @@ func (c *Client) configurePlatformClient() {
}

func (c *Client) SetProxy(proxyAddr string) error {
if c == nil {
return ErrClientIsNil
}
proxyParsed, err := url.Parse(proxyAddr)
if err != nil {
return err
Expand Down Expand Up @@ -203,8 +210,14 @@ func (c *Client) SetEventHandler(handler EventHandler) {
c.eventHandler = handler
}

func (c *Client) handleEvent(evt any) {
if c.eventHandler != nil {
c.eventHandler(evt)
}
}

func (c *Client) UpdateProxy(reason string) bool {
if c.GetNewProxy == nil {
if c == nil || c.GetNewProxy == nil {
return true
}
if proxyAddr, err := c.GetNewProxy(reason); err != nil {
Expand All @@ -218,7 +231,9 @@ func (c *Client) UpdateProxy(reason string) bool {
}

func (c *Client) Connect() error {
if err := c.socket.CanConnect(); err != nil {
if c == nil {
return ErrClientIsNil
} else if err := c.socket.CanConnect(); err != nil {
return err
}
ctx, cancel := context.WithCancel(context.TODO())
Expand All @@ -241,11 +256,11 @@ func (c *Client) Connect() error {
errors.Is(err, CONNECTION_REFUSED_BAD_USERNAME_OR_PASSWORD) ||
// TODO server unavailable may mean a challenge state, should be checked somehow
errors.Is(err, CONNECTION_REFUSED_SERVER_UNAVAILABLE) {
c.eventHandler(&Event_PermanentError{Err: err})
c.handleEvent(&Event_PermanentError{Err: err})
return
}
connectionAttempts += 1
c.eventHandler(&Event_SocketError{Err: err, ConnectionAttempts: connectionAttempts})
c.handleEvent(&Event_SocketError{Err: err, ConnectionAttempts: connectionAttempts})
if time.Since(connectStart) > 2*time.Minute {
reconnectIn = 2 * time.Second
} else {
Expand All @@ -271,23 +286,17 @@ func (c *Client) Connect() error {
}

func (c *Client) Disconnect() {
if c == nil {
return
}
if fn := c.stopCurrentConnection.Load(); fn != nil {
(*fn)()
}
c.socket.Disconnect()
}

func (c *Client) SaveSession(path string) error {
jsonBytes, err := json.Marshal(c.cookies)
if err != nil {
return err
}

return os.WriteFile(path, jsonBytes, os.ModePerm)
}

func (c *Client) IsConnected() bool {
return c.socket.conn != nil
return c != nil && c.socket.conn != nil
}

func (c *Client) sendCookieConsent(jsDatr string) error {
Expand All @@ -305,7 +314,7 @@ func (c *Client) sendCookieConsent(jsDatr string) error {
h.Set("origin", c.getEndpoint("base_url"))
h.Set("cookie", "_js_datr="+jsDatr)
h.Set("referer", c.getEndpoint("login_page"))
q := c.NewHttpQuery()
q := c.newHTTPQuery()
q.AcceptOnlyEssential = "false"
payloadQuery = q
} else {
Expand Down Expand Up @@ -368,17 +377,22 @@ func (c *Client) getEndpointForThreadID(threadID int64) string {
}

func (c *Client) IsAuthenticated() bool {
if c == nil {
return false
}
var isAuthenticated bool
if c.Platform.IsMessenger() {
isAuthenticated = c.configs.browserConfigTable.CurrentUserInitialData.AccountID != "0"
} else {
isAuthenticated = c.configs.browserConfigTable.PolarisViewer.ID != ""
}
return isAuthenticated
return isAuthenticated && c.syncManager != nil
}

func (c *Client) GetCurrentAccount() (types.UserInfo, error) {
if !c.IsAuthenticated() {
if c == nil {
return nil, ErrClientIsNil
} else if !c.IsAuthenticated() {
return nil, fmt.Errorf("messagix-client: not yet authenticated")
}

Expand All @@ -389,7 +403,7 @@ func (c *Client) GetCurrentAccount() (types.UserInfo, error) {
}
}

func (c *Client) GetTaskId() int {
func (c *Client) getTaskID() int {
c.taskMutex.Lock()
defer c.taskMutex.Unlock()
id := 0
Expand All @@ -402,19 +416,28 @@ func (c *Client) GetTaskId() int {
}

func (c *Client) EnableSendingMessages() {
if c == nil {
return
}
c.sendMessagesCond.L.Lock()
c.canSendMessages = true
c.sendMessagesCond.Broadcast()
c.sendMessagesCond.L.Unlock()
}

func (c *Client) disableSendingMessages() {
if c == nil {
return
}
c.sendMessagesCond.L.Lock()
c.canSendMessages = false
c.sendMessagesCond.L.Unlock()
}

func (c *Client) WaitUntilCanSendMessages(timeout time.Duration) error {
if c == nil {
return ErrClientIsNil
}
timer := time.NewTimer(timeout)
defer timer.Stop()

Expand Down
10 changes: 5 additions & 5 deletions pkg/messagix/configs.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ func (c *Configs) SetupConfigs(ls *table.LSTable) (*table.LSTable, error) {
} else {
c.client.socket.broker = c.browserConfigTable.MqttWebConfig.Endpoint
}
c.client.SyncManager.syncParams = &c.browserConfigTable.LSPlatformMessengerSyncParams
c.client.syncManager.syncParams = &c.browserConfigTable.LSPlatformMessengerSyncParams
if len(ls.LSExecuteFinallyBlockForSyncTransaction) == 0 {
c.client.Logger.Warn().Msg("Syncing initial data via graphql")
err := c.client.SyncManager.UpdateDatabaseSyncParams(
err := c.client.syncManager.UpdateDatabaseSyncParams(
[]*socket.QueryMetadata{
{DatabaseId: 1, SendSyncParams: true, LastAppliedCursor: nil, SyncChannel: socket.MailBox},
{DatabaseId: 2, SendSyncParams: true, LastAppliedCursor: nil, SyncChannel: socket.Contact},
Expand All @@ -64,18 +64,18 @@ func (c *Configs) SetupConfigs(ls *table.LSTable) (*table.LSTable, error) {
return ls, fmt.Errorf("failed to update sync params for databases: 1, 2, 95: %w", err)
}

ls, err = c.client.SyncManager.SyncDataGraphQL([]int64{1, 2, 95})
ls, err = c.client.syncManager.SyncDataGraphQL([]int64{1, 2, 95})
if err != nil {
return ls, fmt.Errorf("failed to sync data via graphql for databases: 1, 2, 95: %w", err)
}
} else {
if len(ls.LSUpsertSyncGroupThreadsRange) > 0 {
err := c.client.SyncManager.updateThreadRanges(ls.LSUpsertSyncGroupThreadsRange)
err := c.client.syncManager.updateThreadRanges(ls.LSUpsertSyncGroupThreadsRange)
if err != nil {
return ls, fmt.Errorf("failed to update thread ranges from js module data: %w", err)
}
}
err := c.client.SyncManager.SyncTransactions(ls.LSExecuteFirstBlockForSyncTransaction)
err := c.client.syncManager.SyncTransactions(ls.LSExecuteFirstBlockForSyncTransaction)
if err != nil {
return ls, fmt.Errorf("failed to sync transactions from js module data with syncManager: %w", err)
}
Expand Down
13 changes: 9 additions & 4 deletions pkg/messagix/e2ee-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ import (
"go.mau.fi/mautrix-meta/pkg/messagix/types"
)

func (c *Client) PrepareE2EEClient() *whatsmeow.Client {
if c.device == nil {
panic("PrepareE2EEClient called without device")
func (c *Client) PrepareE2EEClient() (*whatsmeow.Client, error) {
if c == nil {
return nil, ErrClientIsNil
} else if c.device == nil {
return nil, fmt.Errorf("PrepareE2EEClient called without device")
}
e2eeClient := whatsmeow.NewClient(c.device, waLog.Zerolog(c.Logger.With().Str("component", "whatsmeow").Logger()))
e2eeClient.GetClientPayload = c.getClientPayload
Expand All @@ -43,7 +45,7 @@ func (c *Client) PrepareE2EEClient() *whatsmeow.Client {
WebsocketURL: c.getEndpoint("e2ee_ws_url"),
}
e2eeClient.RefreshCAT = c.refreshCAT
return e2eeClient
return e2eeClient, nil
}

type refreshCATResponseGraphQL struct {
Expand All @@ -60,6 +62,9 @@ type refreshCATResponseGraphQL struct {
}

func (c *Client) refreshCAT() error {
if c == nil {
return ErrClientIsNil
}
c.catRefreshLock.Lock()
defer c.catRefreshLock.Unlock()
currentExpiration := time.Unix(c.configs.browserConfigTable.MessengerWebInitData.CryptoAuthToken.ExpirationTimeInSeconds, 0)
Expand Down
8 changes: 6 additions & 2 deletions pkg/messagix/e2ee-register.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,15 @@ func sliceifyIdentities(identities [][32]byte) [][]byte {
}

func (c *Client) SetDevice(dev *store.Device) {
c.device = dev
if c != nil {
c.device = dev
}
}

func (c *Client) RegisterE2EE(ctx context.Context, fbid int64) error {
if c.device == nil {
if c == nil {
return ErrClientIsNil
} else if c.device == nil {
return fmt.Errorf("cannot register for E2EE without a device")
}
if c.device.FacebookUUID == uuid.Nil {
Expand Down
Loading

0 comments on commit 579abac

Please sign in to comment.