From 579abacd7e86c9225140f5678402cd880aab4177 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Feb 2025 16:03:27 +0200 Subject: [PATCH] messagix: make Client nil-safe Closes #117 --- pkg/connector/backfill.go | 9 +++-- pkg/connector/client.go | 9 +++-- pkg/messagix/client.go | 67 +++++++++++++++++++++++------------ pkg/messagix/configs.go | 10 +++--- pkg/messagix/e2ee-client.go | 13 ++++--- pkg/messagix/e2ee-register.go | 8 +++-- pkg/messagix/events.go | 24 ++++++------- pkg/messagix/facebook.go | 2 +- pkg/messagix/graphql.go | 2 +- pkg/messagix/http.go | 5 ++- pkg/messagix/instagram.go | 2 +- pkg/messagix/mercury.go | 9 +++-- pkg/messagix/payload.go | 4 +-- pkg/messagix/socket.go | 4 +-- pkg/messagix/syncManager.go | 9 ++++- pkg/messagix/taskManager.go | 8 ++--- pkg/messagix/threads.go | 5 ++- 17 files changed, 123 insertions(+), 67 deletions(-) diff --git a/pkg/connector/backfill.go b/pkg/connector/backfill.go index fa572ff..9a38c55 100644 --- a/pkg/connector/backfill.go +++ b/pkg/connector/backfill.go @@ -132,7 +132,7 @@ func (m *MetaClient) requestMoreHistory(ctx context.Context, threadID, minTimest ReferenceTimestampMs: minTimestampMS, ReferenceMessageId: minMessageID, SyncGroup: 1, - Cursor: m.Client.SyncManager.GetCursor(1), + Cursor: m.Client.GetCursor(1), }) zerolog.Ctx(ctx).Trace(). Int64("thread_id", threadID). @@ -172,7 +172,7 @@ func (m *MetaClient) removeBackfillCollector(threadID int64, collector *Backfill } func (m *MetaClient) FetchMessages(ctx context.Context, params bridgev2.FetchMessagesParams) (*bridgev2.FetchMessagesResponse, error) { - if m.Client == nil || m.Client.SyncManager == nil { + if m.Client == nil { return nil, bridgev2.ErrNotLoggedIn } if params.Portal.Metadata.(*metaid.PortalMetadata).ThreadType == table.ENCRYPTED_OVER_WA_GROUP { @@ -231,7 +231,10 @@ func (m *MetaClient) FetchMessages(ctx context.Context, params bridgev2.FetchMes if !m.addBackfillCollector(threadID, collector) { return nil, fmt.Errorf("backfill collector already exists for thread %d", threadID) } - m.requestMoreHistory(ctx, threadID, oldestMessageTS, oldestMessageID) + if !m.requestMoreHistory(ctx, threadID, oldestMessageTS, oldestMessageID) { + m.removeBackfillCollector(threadID, collector) + return nil, fmt.Errorf("failed to request more history for thread %d", threadID) + } select { case <-doneCh: upsert = collector.UpsertMessages diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 91b0b2a..f9ce658 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -359,7 +359,10 @@ func (m *MetaClient) connectE2EE() error { return fmt.Errorf("failed to save device ID to user login: %w", err) } } - m.E2EEClient = m.Client.PrepareE2EEClient() + m.E2EEClient, err = m.Client.PrepareE2EEClient() + if err != nil { + return fmt.Errorf("failed to prepare e2ee client: %w", err) + } m.E2EEClient.AddEventHandler(m.e2eeEventHandler) err = m.E2EEClient.Connect() if err != nil { @@ -373,10 +376,12 @@ func (m *MetaClient) Disconnect() { (*stopConnectAttempt)() } if cli := m.Client; cli != nil { + cli.SetEventHandler(nil) cli.Disconnect() m.Client = nil } if ecli := m.E2EEClient; ecli != nil { + ecli.RemoveEventHandlers() ecli.Disconnect() m.E2EEClient = nil } @@ -391,7 +396,7 @@ func (m *MetaClient) Disconnect() { } func (m *MetaClient) IsLoggedIn() bool { - return m.Client != nil && m.Client.SyncManager != nil + return m.Client.IsAuthenticated() } func (m *MetaClient) IsThisUser(ctx context.Context, userID networkid.UserID) bool { diff --git a/pkg/messagix/client.go b/pkg/messagix/client.go index 64aaf33..612a5b6 100644 --- a/pkg/messagix/client.go +++ b/pkg/messagix/client.go @@ -9,7 +9,6 @@ import ( "net" "net/http" "net/url" - "os" "slices" "strconv" "sync" @@ -18,6 +17,7 @@ import ( "github.com/google/go-querystring/query" "github.com/rs/zerolog" + "go.mau.fi/whatsmeow" "go.mau.fi/whatsmeow/store" "golang.org/x/net/proxy" @@ -43,6 +43,8 @@ const SecCHMobile = "?0" const SecCHModel = `""` const SecCHPrefersColorScheme = "light" +var ErrClientIsNil = whatsmeow.ErrClientIsNil + type EventHandler func(evt interface{}) type Client struct { Instagram *InstagramMethods @@ -54,7 +56,7 @@ type Client struct { socket *Socket eventHandler EventHandler configs *Configs - SyncManager *SyncManager + syncManager *SyncManager cookies *cookies.Cookies httpProxy func(*http.Request) (*url.URL, error) @@ -124,7 +126,9 @@ func NewClient(cookies *cookies.Cookies, logger zerolog.Logger) *Client { } func (c *Client) LoadMessagesPage() (types.UserInfo, *table.LSTable, error) { - if !c.cookies.IsLoggedIn() { + if c == nil { + return nil, nil, ErrClientIsNil + } else if !c.cookies.IsLoggedIn() { return nil, nil, fmt.Errorf("can't load messages page without being authenticated") } @@ -134,7 +138,7 @@ func (c *Client) LoadMessagesPage() (types.UserInfo, *table.LSTable, error) { return nil, nil, fmt.Errorf("failed to load inbox: %w", err) } - c.SyncManager = c.NewSyncManager() + c.syncManager = c.newSyncManager() ls, err := c.configs.SetupConfigs(moduleLoader.LS) if err != nil { return nil, nil, err @@ -175,6 +179,9 @@ func (c *Client) configurePlatformClient() { } func (c *Client) SetProxy(proxyAddr string) error { + if c == nil { + return ErrClientIsNil + } proxyParsed, err := url.Parse(proxyAddr) if err != nil { return err @@ -203,8 +210,14 @@ func (c *Client) SetEventHandler(handler EventHandler) { c.eventHandler = handler } +func (c *Client) handleEvent(evt any) { + if c.eventHandler != nil { + c.eventHandler(evt) + } +} + func (c *Client) UpdateProxy(reason string) bool { - if c.GetNewProxy == nil { + if c == nil || c.GetNewProxy == nil { return true } if proxyAddr, err := c.GetNewProxy(reason); err != nil { @@ -218,7 +231,9 @@ func (c *Client) UpdateProxy(reason string) bool { } func (c *Client) Connect() error { - if err := c.socket.CanConnect(); err != nil { + if c == nil { + return ErrClientIsNil + } else if err := c.socket.CanConnect(); err != nil { return err } ctx, cancel := context.WithCancel(context.TODO()) @@ -241,11 +256,11 @@ func (c *Client) Connect() error { errors.Is(err, CONNECTION_REFUSED_BAD_USERNAME_OR_PASSWORD) || // TODO server unavailable may mean a challenge state, should be checked somehow errors.Is(err, CONNECTION_REFUSED_SERVER_UNAVAILABLE) { - c.eventHandler(&Event_PermanentError{Err: err}) + c.handleEvent(&Event_PermanentError{Err: err}) return } connectionAttempts += 1 - c.eventHandler(&Event_SocketError{Err: err, ConnectionAttempts: connectionAttempts}) + c.handleEvent(&Event_SocketError{Err: err, ConnectionAttempts: connectionAttempts}) if time.Since(connectStart) > 2*time.Minute { reconnectIn = 2 * time.Second } else { @@ -271,23 +286,17 @@ func (c *Client) Connect() error { } func (c *Client) Disconnect() { + if c == nil { + return + } if fn := c.stopCurrentConnection.Load(); fn != nil { (*fn)() } c.socket.Disconnect() } -func (c *Client) SaveSession(path string) error { - jsonBytes, err := json.Marshal(c.cookies) - if err != nil { - return err - } - - return os.WriteFile(path, jsonBytes, os.ModePerm) -} - func (c *Client) IsConnected() bool { - return c.socket.conn != nil + return c != nil && c.socket.conn != nil } func (c *Client) sendCookieConsent(jsDatr string) error { @@ -305,7 +314,7 @@ func (c *Client) sendCookieConsent(jsDatr string) error { h.Set("origin", c.getEndpoint("base_url")) h.Set("cookie", "_js_datr="+jsDatr) h.Set("referer", c.getEndpoint("login_page")) - q := c.NewHttpQuery() + q := c.newHTTPQuery() q.AcceptOnlyEssential = "false" payloadQuery = q } else { @@ -368,17 +377,22 @@ func (c *Client) getEndpointForThreadID(threadID int64) string { } func (c *Client) IsAuthenticated() bool { + if c == nil { + return false + } var isAuthenticated bool if c.Platform.IsMessenger() { isAuthenticated = c.configs.browserConfigTable.CurrentUserInitialData.AccountID != "0" } else { isAuthenticated = c.configs.browserConfigTable.PolarisViewer.ID != "" } - return isAuthenticated + return isAuthenticated && c.syncManager != nil } func (c *Client) GetCurrentAccount() (types.UserInfo, error) { - if !c.IsAuthenticated() { + if c == nil { + return nil, ErrClientIsNil + } else if !c.IsAuthenticated() { return nil, fmt.Errorf("messagix-client: not yet authenticated") } @@ -389,7 +403,7 @@ func (c *Client) GetCurrentAccount() (types.UserInfo, error) { } } -func (c *Client) GetTaskId() int { +func (c *Client) getTaskID() int { c.taskMutex.Lock() defer c.taskMutex.Unlock() id := 0 @@ -402,6 +416,9 @@ func (c *Client) GetTaskId() int { } func (c *Client) EnableSendingMessages() { + if c == nil { + return + } c.sendMessagesCond.L.Lock() c.canSendMessages = true c.sendMessagesCond.Broadcast() @@ -409,12 +426,18 @@ func (c *Client) EnableSendingMessages() { } func (c *Client) disableSendingMessages() { + if c == nil { + return + } c.sendMessagesCond.L.Lock() c.canSendMessages = false c.sendMessagesCond.L.Unlock() } func (c *Client) WaitUntilCanSendMessages(timeout time.Duration) error { + if c == nil { + return ErrClientIsNil + } timer := time.NewTimer(timeout) defer timer.Stop() diff --git a/pkg/messagix/configs.go b/pkg/messagix/configs.go index b1cc6e3..a2f208e 100644 --- a/pkg/messagix/configs.go +++ b/pkg/messagix/configs.go @@ -50,10 +50,10 @@ func (c *Configs) SetupConfigs(ls *table.LSTable) (*table.LSTable, error) { } else { c.client.socket.broker = c.browserConfigTable.MqttWebConfig.Endpoint } - c.client.SyncManager.syncParams = &c.browserConfigTable.LSPlatformMessengerSyncParams + c.client.syncManager.syncParams = &c.browserConfigTable.LSPlatformMessengerSyncParams if len(ls.LSExecuteFinallyBlockForSyncTransaction) == 0 { c.client.Logger.Warn().Msg("Syncing initial data via graphql") - err := c.client.SyncManager.UpdateDatabaseSyncParams( + err := c.client.syncManager.UpdateDatabaseSyncParams( []*socket.QueryMetadata{ {DatabaseId: 1, SendSyncParams: true, LastAppliedCursor: nil, SyncChannel: socket.MailBox}, {DatabaseId: 2, SendSyncParams: true, LastAppliedCursor: nil, SyncChannel: socket.Contact}, @@ -64,18 +64,18 @@ func (c *Configs) SetupConfigs(ls *table.LSTable) (*table.LSTable, error) { return ls, fmt.Errorf("failed to update sync params for databases: 1, 2, 95: %w", err) } - ls, err = c.client.SyncManager.SyncDataGraphQL([]int64{1, 2, 95}) + ls, err = c.client.syncManager.SyncDataGraphQL([]int64{1, 2, 95}) if err != nil { return ls, fmt.Errorf("failed to sync data via graphql for databases: 1, 2, 95: %w", err) } } else { if len(ls.LSUpsertSyncGroupThreadsRange) > 0 { - err := c.client.SyncManager.updateThreadRanges(ls.LSUpsertSyncGroupThreadsRange) + err := c.client.syncManager.updateThreadRanges(ls.LSUpsertSyncGroupThreadsRange) if err != nil { return ls, fmt.Errorf("failed to update thread ranges from js module data: %w", err) } } - err := c.client.SyncManager.SyncTransactions(ls.LSExecuteFirstBlockForSyncTransaction) + err := c.client.syncManager.SyncTransactions(ls.LSExecuteFirstBlockForSyncTransaction) if err != nil { return ls, fmt.Errorf("failed to sync transactions from js module data with syncManager: %w", err) } diff --git a/pkg/messagix/e2ee-client.go b/pkg/messagix/e2ee-client.go index 99d5299..85a6c58 100644 --- a/pkg/messagix/e2ee-client.go +++ b/pkg/messagix/e2ee-client.go @@ -31,9 +31,11 @@ import ( "go.mau.fi/mautrix-meta/pkg/messagix/types" ) -func (c *Client) PrepareE2EEClient() *whatsmeow.Client { - if c.device == nil { - panic("PrepareE2EEClient called without device") +func (c *Client) PrepareE2EEClient() (*whatsmeow.Client, error) { + if c == nil { + return nil, ErrClientIsNil + } else if c.device == nil { + return nil, fmt.Errorf("PrepareE2EEClient called without device") } e2eeClient := whatsmeow.NewClient(c.device, waLog.Zerolog(c.Logger.With().Str("component", "whatsmeow").Logger())) e2eeClient.GetClientPayload = c.getClientPayload @@ -43,7 +45,7 @@ func (c *Client) PrepareE2EEClient() *whatsmeow.Client { WebsocketURL: c.getEndpoint("e2ee_ws_url"), } e2eeClient.RefreshCAT = c.refreshCAT - return e2eeClient + return e2eeClient, nil } type refreshCATResponseGraphQL struct { @@ -60,6 +62,9 @@ type refreshCATResponseGraphQL struct { } func (c *Client) refreshCAT() error { + if c == nil { + return ErrClientIsNil + } c.catRefreshLock.Lock() defer c.catRefreshLock.Unlock() currentExpiration := time.Unix(c.configs.browserConfigTable.MessengerWebInitData.CryptoAuthToken.ExpirationTimeInSeconds, 0) diff --git a/pkg/messagix/e2ee-register.go b/pkg/messagix/e2ee-register.go index 3bd29f0..4af3f8a 100644 --- a/pkg/messagix/e2ee-register.go +++ b/pkg/messagix/e2ee-register.go @@ -143,11 +143,15 @@ func sliceifyIdentities(identities [][32]byte) [][]byte { } func (c *Client) SetDevice(dev *store.Device) { - c.device = dev + if c != nil { + c.device = dev + } } func (c *Client) RegisterE2EE(ctx context.Context, fbid int64) error { - if c.device == nil { + if c == nil { + return ErrClientIsNil + } else if c.device == nil { return fmt.Errorf("cannot register for E2EE without a device") } if c.device.FacebookUUID == uuid.Nil { diff --git a/pkg/messagix/events.go b/pkg/messagix/events.go index 2cf6cec..be840ff 100644 --- a/pkg/messagix/events.go +++ b/pkg/messagix/events.go @@ -18,11 +18,11 @@ import ( func (s *Socket) handleReadyEvent(data *Event_Ready) error { if s.previouslyConnected { s.client.EnableSendingMessages() - err := s.client.SyncManager.EnsureSyncedSocket(reconnectSync[s.client.Platform]) + err := s.client.syncManager.EnsureSyncedSocket(reconnectSync[s.client.Platform]) if err != nil { return fmt.Errorf("failed to sync after reconnect: %w", err) } - s.client.eventHandler(&Event_Reconnected{}) + s.client.handleEvent(&Event_Reconnected{}) return nil } appSettingPublishJSON, err := s.newAppSettingsPublishJSON(s.client.configs.VersionId) @@ -47,14 +47,14 @@ func (s *Socket) handleReadyEvent(data *Event_Ready) error { s.client.EnableSendingMessages() - tskm := s.client.NewTaskManager() + tskm := s.client.newTaskManager() tskm.AddNewTask(&socket.FetchThreadsTask{ IsAfter: 0, ParentThreadKey: -1, ReferenceThreadKey: 0, ReferenceActivityTimestamp: 9999999999999, AdditionalPagesToFetch: 0, - Cursor: s.client.SyncManager.GetCursor(1), + Cursor: s.client.syncManager.GetCursor(1), SyncGroup: 1, }) tskm.AddNewTask(&socket.FetchThreadsTask{ @@ -66,16 +66,16 @@ func (s *Socket) handleReadyEvent(data *Event_Ready) error { SyncGroup: 95, }) - syncGroupKeyStore1 := s.client.SyncManager.getSyncGroupKeyStore(1) + syncGroupKeyStore1 := s.client.syncManager.getSyncGroupKeyStore(1) if syncGroupKeyStore1 != nil { - // syncGroupKeyStore95 := s.client.SyncManager.getSyncGroupKeyStore(95) + // syncGroupKeyStore95 := s.client.syncManager.getSyncGroupKeyStore(95) tskm.AddNewTask(&socket.FetchThreadsTask{ IsAfter: 0, ParentThreadKey: syncGroupKeyStore1.ParentThreadKey, ReferenceThreadKey: syncGroupKeyStore1.MinThreadKey, ReferenceActivityTimestamp: syncGroupKeyStore1.MinLastActivityTimestampMs, AdditionalPagesToFetch: 0, - Cursor: s.client.SyncManager.GetCursor(1), + Cursor: s.client.syncManager.GetCursor(1), SyncGroup: 1, }) tskm.AddNewTask(&socket.FetchThreadsTask{ @@ -104,13 +104,13 @@ func (s *Socket) handleReadyEvent(data *Event_Ready) error { return fmt.Errorf("failed to report app state: %w", err) } - err = s.client.SyncManager.EnsureSyncedSocket(initialSync[s.client.Platform]) + err = s.client.syncManager.EnsureSyncedSocket(initialSync[s.client.Platform]) if err != nil { return fmt.Errorf("failed to ensure db 1 is synced: %w", err) } data.client = s.client - s.client.eventHandler(data.Finish()) + s.client.handleEvent(data.Finish()) s.previouslyConnected = true return nil @@ -141,13 +141,13 @@ func (s *Socket) handlePublishResponseEvent(resp *Event_PublishResponse, qos pac Any("LSExecuteFirstBlockForSyncTransaction", resp.Table.LSExecuteFirstBlockForSyncTransaction). Any("LSUpsertSyncGroupThreadsRange", resp.Table.LSUpsertSyncGroupThreadsRange). Msg("Updating sync groups") - //err := s.client.SyncManager.SyncTransactions(transactions) - err := s.client.SyncManager.updateSyncGroupCursors(resp.Table) + //err := s.client.syncManager.SyncTransactions(transactions) + err := s.client.syncManager.updateSyncGroupCursors(resp.Table) if err != nil { s.client.Logger.Err(err).Msg("Failed to sync transactions from publish response event") } } - s.client.eventHandler(resp) + s.client.handleEvent(resp) } else { s.client.Logger.Debug().Int64("packet_id", packetId).Msg("Got unexpected lightspeed publish response") } diff --git a/pkg/messagix/facebook.go b/pkg/messagix/facebook.go index 1d5ba25..4b94777 100644 --- a/pkg/messagix/facebook.go +++ b/pkg/messagix/facebook.go @@ -89,7 +89,7 @@ func (fb *FacebookMethods) RegisterPushNotifications(endpoint string, keys PushK return err } - payload := c.NewHttpQuery() + payload := c.newHTTPQuery() payload.AppID = "1443096165982425" payload.PushEndpoint = endpoint payload.SubscriptionKeys = string(jsonKeys) diff --git a/pkg/messagix/graphql.go b/pkg/messagix/graphql.go index 03cf575..bf2870d 100644 --- a/pkg/messagix/graphql.go +++ b/pkg/messagix/graphql.go @@ -26,7 +26,7 @@ func (c *Client) makeGraphQLRequest(name string, variables interface{}) (*http.R return nil, nil, fmt.Errorf("failed to marshal graphql variables to json string: %w", err) } - payload := c.NewHttpQuery() + payload := c.newHTTPQuery() payload.FbAPICallerClass = graphQLDoc.CallerClass payload.FbAPIReqFriendlyName = graphQLDoc.FriendlyName payload.Variables = string(vBytes) diff --git a/pkg/messagix/http.go b/pkg/messagix/http.go index b0f72b1..4454d0d 100644 --- a/pkg/messagix/http.go +++ b/pkg/messagix/http.go @@ -52,7 +52,7 @@ type HttpQuery struct { Aaid string `url:"__aaid,omitempty"` } -func (c *Client) NewHttpQuery() *HttpQuery { +func (c *Client) newHTTPQuery() *HttpQuery { c.graphQLRequests++ siteConfig := c.configs.browserConfigTable.SiteData dpr := strconv.FormatFloat(siteConfig.Pr, 'g', 4, 64) @@ -143,6 +143,9 @@ func (c *Client) checkHTTPRedirect(req *http.Request, via []*http.Request) error } func (c *Client) MakeRequest(url string, method string, headers http.Header, payload []byte, contentType types.ContentType) (*http.Response, []byte, error) { + if c == nil { + return nil, nil, ErrClientIsNil + } var attempts int for { attempts++ diff --git a/pkg/messagix/instagram.go b/pkg/messagix/instagram.go index f7ee776..877b70a 100644 --- a/pkg/messagix/instagram.go +++ b/pkg/messagix/instagram.go @@ -192,7 +192,7 @@ func (ig *InstagramMethods) RegisterPushNotifications(endpoint string, keys Push } u := uuid.New() - payload := c.NewHttpQuery() + payload := c.newHTTPQuery() payload.Mid = u.String() payload.DeviceType = "web_vapid" payload.DeviceToken = endpoint diff --git a/pkg/messagix/mercury.go b/pkg/messagix/mercury.go index b449070..7776f4c 100644 --- a/pkg/messagix/mercury.go +++ b/pkg/messagix/mercury.go @@ -31,7 +31,10 @@ type WaveformData struct { } func (c *Client) SendMercuryUploadRequest(ctx context.Context, threadID int64, media *MercuryUploadMedia) (*types.MercuryUploadResponse, error) { - urlQueries := c.NewHttpQuery() + if c == nil { + return nil, ErrClientIsNil + } + urlQueries := c.newHTTPQuery() queryValues, err := query.Values(urlQueries) if err != nil { return nil, fmt.Errorf("failed to convert HttpQuery into query.Values for mercury upload: %w", err) @@ -39,7 +42,7 @@ func (c *Client) SendMercuryUploadRequest(ctx context.Context, threadID int64, m payloadQuery := queryValues.Encode() url := c.getEndpoint("media_upload") + payloadQuery - payload, contentType, err := c.NewMercuryMediaPayload(media) + payload, contentType, err := c.newMercuryMediaPayload(media) if err != nil { return nil, err } @@ -131,7 +134,7 @@ func (c *Client) parseMetadata(response *types.MercuryUploadResponse) error { } // returns payloadBytes, multipart content-type header -func (c *Client) NewMercuryMediaPayload(media *MercuryUploadMedia) ([]byte, string, error) { +func (c *Client) newMercuryMediaPayload(media *MercuryUploadMedia) ([]byte, string, error) { var mercuryPayload bytes.Buffer writer := multipart.NewWriter(&mercuryPayload) diff --git a/pkg/messagix/payload.go b/pkg/messagix/payload.go index fc12d20..387a56d 100644 --- a/pkg/messagix/payload.go +++ b/pkg/messagix/payload.go @@ -49,7 +49,7 @@ func (pb *PublishPayload) Write() ([]byte, error) { return byter.NewWriter().WriteFromStruct(pb) } -func (c *Client) NewPublishRequest(topic Topic, jsonData string, packetByte byte, packetId uint16) ([]byte, uint16, error) { +func (c *Client) newPublishRequest(topic Topic, jsonData string, packetByte byte, packetId uint16) ([]byte, uint16, error) { payload := &PublishPayload{ Topic: topic, PacketId: packetId, @@ -81,7 +81,7 @@ func (sb *SubscribePayload) Write() ([]byte, error) { return byter.NewWriter().WriteFromStruct(sb) } -func (c *Client) NewSubscribeRequest(topic Topic, qos packets.QoS) ([]byte, uint16, error) { +func (c *Client) newSubscribeRequest(topic Topic, qos packets.QoS) ([]byte, uint16, error) { packetByte := &packets.SubscribePacket{} packetId := c.socket.SafePacketId() c.socket.responseHandler.addPacketChannel(packetId) diff --git a/pkg/messagix/socket.go b/pkg/messagix/socket.go index 8579d1d..6596004 100644 --- a/pkg/messagix/socket.go +++ b/pkg/messagix/socket.go @@ -347,7 +347,7 @@ func (s *Socket) sendConnectPacket() error { } func (s *Socket) sendSubscribePacket(topic Topic, qos packets.QoS, wait bool) (*Event_SubscribeACK, error) { - subscribeRequestPayload, packetId, err := s.client.NewSubscribeRequest(topic, qos) + subscribeRequestPayload, packetId, err := s.client.newSubscribeRequest(topic, qos) if err != nil { return nil, err } @@ -372,7 +372,7 @@ func (s *Socket) sendSubscribePacket(topic Topic, qos packets.QoS, wait bool) (* } func (s *Socket) sendPublishPacket(topic Topic, jsonData string, packet *packets.PublishPacket, packetId uint16) (uint16, error) { - publishRequestPayload, packetId, err := s.client.NewPublishRequest(topic, jsonData, packet.Compress(), packetId) + publishRequestPayload, packetId, err := s.client.newPublishRequest(topic, jsonData, packet.Compress(), packetId) if err != nil { return packetId, err } diff --git a/pkg/messagix/syncManager.go b/pkg/messagix/syncManager.go index b34bb94..f5bc979 100644 --- a/pkg/messagix/syncManager.go +++ b/pkg/messagix/syncManager.go @@ -22,7 +22,7 @@ type SyncManager struct { syncParams *types.LSPlatformMessengerSyncParams } -func (c *Client) NewSyncManager() *SyncManager { +func (c *Client) newSyncManager() *SyncManager { return &SyncManager{ client: c, store: map[int64]*socket.QueryMetadata{ @@ -251,6 +251,13 @@ func (sm *SyncManager) GetCursor(db int64) string { return *database.LastAppliedCursor } +func (c *Client) GetCursor(db int64) string { + if c == nil || c.syncManager == nil { + return "" + } + return c.syncManager.GetCursor(db) +} + func (sm *SyncManager) updateThreadRanges(ranges []*table.LSUpsertSyncGroupThreadsRange) error { var err error for _, syncGroupData := range ranges { diff --git a/pkg/messagix/taskManager.go b/pkg/messagix/taskManager.go index eb60526..e970971 100644 --- a/pkg/messagix/taskManager.go +++ b/pkg/messagix/taskManager.go @@ -14,7 +14,7 @@ type TaskManager struct { traceId string } -func (c *Client) NewTaskManager() *TaskManager { +func (c *Client) newTaskManager() *TaskManager { return &TaskManager{ client: c, currTasks: make([]socket.TaskData, 0), @@ -62,13 +62,13 @@ func (tm *TaskManager) AddNewTask(task socket.Task) { Label: label, Payload: string(payloadMarshalled), QueueName: queueName, - TaskId: tm.GetTaskId(), + TaskId: tm.GetTaskID(), } tm.client.Logger.Trace().Any("label", label).Any("payload", payload).Any("queueName", queueName).Any("taskId", taskData.TaskId).Msg("Creating task") tm.currTasks = append(tm.currTasks, taskData) } -func (tm *TaskManager) GetTaskId() int64 { - return int64(tm.client.GetTaskId()) +func (tm *TaskManager) GetTaskID() int64 { + return int64(tm.client.getTaskID()) } diff --git a/pkg/messagix/threads.go b/pkg/messagix/threads.go index 01cbf6f..b528f7c 100644 --- a/pkg/messagix/threads.go +++ b/pkg/messagix/threads.go @@ -8,7 +8,10 @@ import ( ) func (c *Client) ExecuteTasks(tasks ...socket.Task) (*table.LSTable, error) { - tskm := c.NewTaskManager() + if c == nil { + return nil, ErrClientIsNil + } + tskm := c.newTaskManager() for _, task := range tasks { tskm.AddNewTask(task) }