diff --git a/src/Signal/libsignal.ts b/src/Signal/libsignal.ts index 9a2164acff5..e6426b47599 100644 --- a/src/Signal/libsignal.ts +++ b/src/Signal/libsignal.ts @@ -7,6 +7,7 @@ import type { LIDMapping, SignalAuthState, SignalKeyStoreWithTransaction } from import type { SignalRepositoryWithLIDStore } from '../Types/Signal' import { generateSignalPubKey } from '../Utils' import type { ILogger } from '../Utils/logger' +import { makeKeyedMutex } from '../Utils/make-mutex' import { isHostedLidUser, isHostedPnUser, @@ -57,11 +58,55 @@ export function makeLibSignalRepository( const parsedKeys = auth.keys as SignalKeyStoreWithTransaction const migratedSessionCache = new LRUCache({ - ttl: 3 * 24 * 60 * 60 * 1000, // 7 days + // Bounded to prevent unbounded growth for high-cardinality accounts. + max: 10_000, + ttl: 3 * 24 * 60 * 60 * 1000, ttlAutopurge: true, updateAgeOnGet: true }) + // `parsedKeys.transaction` skips its mutex when nested in another ALS + // transaction, so two outbound sends triggered from different recv + // handlers can hit the same Signal session in parallel. This mutex + // always acquires, keyed by the resolved (PN→LID) Signal address so + // PN- and LID-addressed calls converge. + const sessionMutex = makeKeyedMutex() + + // Sorted to prevent deadlocks between callers asking for overlapping sets. + const withSessionLocks = async (keys: string[], work: () => Promise): Promise => { + const sorted = Array.from(new Set(keys)).sort() + const acquire = async (idx: number): Promise => { + if (idx >= sorted.length) return work() + return sessionMutex.mutex(sorted[idx]!, () => acquire(idx + 1)) + } + + return acquire(0) + } + + // `signalStorage.{load,store}Session` resolve PN→LID internally; the cipher + // is constructed with this pre-resolved address so storage's `id` is already + // LID-form and its internal re-resolution becomes a no-op. Without this, + // the LID mapping could populate between mutex acquisition and the storage + // hook, locking one key and mutating another. + const resolveSignalAddress = async (jid: string): Promise => { + if (isPnUser(jid) || isHostedPnUser(jid)) { + const lidForPN = await lidMapping.getLIDForPN(jid) + if (lidForPN) { + return jidToSignalProtocolAddress(lidForPN) + } + } + + return jidToSignalProtocolAddress(jid) + } + + const resolveSessionKey = async (jid: string): Promise => (await resolveSignalAddress(jid)).toString() + + // Fall back to `transaction` for external auth stores that don't implement + // `isolatedTransaction`. The fallback reintroduces the visibility race for + // those stores but preserves backwards compatibility. + const runIsolated = (exec: () => Promise, key: string): Promise => + parsedKeys.isolatedTransaction ? parsedKeys.isolatedTransaction(exec) : parsedKeys.transaction(exec, key) + const ensureSenderKeyAndCreateSkdm = async (group: string, meId: string) => { const senderName = jidToSignalSenderKeyName(group, meId) const senderNameStr = senderName.toString() @@ -115,52 +160,48 @@ export function makeLibSignalRepository( }, item.groupId) }, async decryptMessage({ jid, type, ciphertext }) { - const addr = jidToSignalProtocolAddress(jid) + const addr = await resolveSignalAddress(jid) const session = new libsignal.SessionCipher(storage, addr) - - // Extract and save sender's identity key before decryption for identity change detection - if (type === 'pkmsg') { - const identityKey = extractIdentityFromPkmsg(ciphertext) - if (identityKey) { - const addrStr = addr.toString() - const identityChanged = await storage.saveIdentity(addrStr, identityKey) - if (identityChanged) { - logger.info({ jid, addr: addrStr }, 'identity key changed or new contact, session will be re-established') + const sessionKey = addr.toString() + + // `saveIdentity` may delete the session when the identity key + // changes, so it must hold the lock alongside the decrypt. + return sessionMutex.mutex(sessionKey, () => + runIsolated(async () => { + if (type === 'pkmsg') { + const identityKey = extractIdentityFromPkmsg(ciphertext) + if (identityKey) { + const identityChanged = await storage.saveIdentity(sessionKey, identityKey) + if (identityChanged) { + logger.info( + { jid, addr: sessionKey }, + 'identity key changed or new contact, session will be re-established' + ) + } + } } - } - } - async function doDecrypt() { - let result: Buffer - switch (type) { - case 'pkmsg': - result = await session.decryptPreKeyWhisperMessage(ciphertext) - break - case 'msg': - result = await session.decryptWhisperMessage(ciphertext) - break - } - - return result - } - - // If it's not a sync message, we need to ensure atomicity - // For regular messages, we use a transaction to ensure atomicity - return parsedKeys.transaction(async () => { - return await doDecrypt() - }, jid) + switch (type) { + case 'pkmsg': + return await session.decryptPreKeyWhisperMessage(ciphertext) + case 'msg': + return await session.decryptWhisperMessage(ciphertext) + } + }, sessionKey) + ) }, async encryptMessage({ jid, data }) { - const addr = jidToSignalProtocolAddress(jid) + const addr = await resolveSignalAddress(jid) const cipher = new libsignal.SessionCipher(storage, addr) - - // Use transaction to ensure atomicity - return parsedKeys.transaction(async () => { - const { type: sigType, body } = await cipher.encrypt(data) - const type = sigType === 3 ? 'pkmsg' : 'msg' - return { type, ciphertext: Buffer.from(body, 'binary') } - }, jid) + const sessionKey = addr.toString() + return sessionMutex.mutex(sessionKey, () => + runIsolated(async () => { + const { type: sigType, body } = await cipher.encrypt(data) + const type: 'pkmsg' | 'msg' = sigType === 3 ? 'pkmsg' : 'msg' + return { type, ciphertext: Buffer.from(body, 'binary') } + }, sessionKey) + ) }, async encryptGroupMessage({ group, meId, data }) { @@ -205,12 +246,15 @@ export function makeLibSignalRepository( async injectE2ESession({ jid, session }) { logger.trace({ jid }, 'injecting E2EE session') - const cipher = new libsignal.SessionBuilder(storage, jidToSignalProtocolAddress(jid)) - return parsedKeys.transaction(async () => { - // libsignal runtime accepts an absent prekey (initOutgoing checks `device.preKey && ...`) - // but the bundled .d.ts marks it required. - await cipher.initOutgoing(session as unknown as Parameters[0]) - }, jid) + const addr = await resolveSignalAddress(jid) + const cipher = new libsignal.SessionBuilder(storage, addr) + const sessionKey = addr.toString() + return sessionMutex.mutex(sessionKey, () => + runIsolated(async () => { + // `.d.ts` marks prekey required but runtime allows it absent. + await cipher.initOutgoing(session as unknown as Parameters[0]) + }, sessionKey) + ) }, jidToSignalProtocolAddress(jid) { return jidToSignalProtocolAddress(jid).toString() @@ -238,20 +282,32 @@ export function makeLibSignalRepository( } }, + async deleteSessionForJid(jid: string) { + const sessionKey = await resolveSessionKey(jid) + await sessionMutex.mutex(sessionKey, () => + runIsolated(async () => { + await auth.keys.set({ session: { [sessionKey]: null } }) + }, sessionKey) + ) + }, + async deleteSession(jids: string[]) { if (!jids.length) return - // Convert JIDs to signal addresses and prepare for bulk deletion const sessionUpdates: { [key: string]: null } = {} - jids.forEach(jid => { - const addr = jidToSignalProtocolAddress(jid) - sessionUpdates[addr.toString()] = null - }) + const sessionKeys = await Promise.all( + jids.map(async jid => { + const sessionKey = await resolveSessionKey(jid) + sessionUpdates[sessionKey] = null + return sessionKey + }) + ) - // Single transaction for all deletions - return parsedKeys.transaction(async () => { - await auth.keys.set({ session: sessionUpdates }) - }, `delete-${jids.length}-sessions`) + return withSessionLocks(sessionKeys, () => + runIsolated(async () => { + await auth.keys.set({ session: sessionUpdates }) + }, sessionKeys[0]!) + ) }, close() { @@ -324,82 +380,87 @@ export function makeLibSignalRepository( 'bulk device migration complete - all user devices processed' ) - // Single transaction for all migrations - return parsedKeys.transaction( - async (): Promise<{ migrated: number; skipped: number; total: number }> => { - // Prepare migration operations with addressing metadata - type MigrationOp = { - fromJid: string - toJid: string - pnUser: string - lidUser: string - deviceId: number - fromAddr: libsignal.ProtocolAddress - toAddr: libsignal.ProtocolAddress - } + // Prepare migration operations with addressing metadata + type MigrationOp = { + fromJid: string + toJid: string + pnUser: string + lidUser: string + deviceId: number + fromAddr: libsignal.ProtocolAddress + toAddr: libsignal.ProtocolAddress + } - const migrationOps: MigrationOp[] = deviceJids.map(jid => { - const lidWithDevice = transferDevice(jid, toJid) - const fromDecoded = jidDecode(jid)! - const toDecoded = jidDecode(lidWithDevice)! - - return { - fromJid: jid, - toJid: lidWithDevice, - pnUser: fromDecoded.user, - lidUser: toDecoded.user, - deviceId: fromDecoded.device || 0, - fromAddr: jidToSignalProtocolAddress(jid), - toAddr: jidToSignalProtocolAddress(lidWithDevice) - } - }) + const migrationOps: MigrationOp[] = deviceJids.map(jid => { + const lidWithDevice = transferDevice(jid, toJid) + const fromDecoded = jidDecode(jid)! + const toDecoded = jidDecode(lidWithDevice)! - const totalOps = migrationOps.length - let migratedCount = 0 + return { + fromJid: jid, + toJid: lidWithDevice, + pnUser: fromDecoded.user, + lidUser: toDecoded.user, + deviceId: fromDecoded.device || 0, + fromAddr: jidToSignalProtocolAddress(jid), + toAddr: jidToSignalProtocolAddress(lidWithDevice) + } + }) - // Bulk fetch PN sessions - already exist (verified during device discovery) - const pnAddrStrings = Array.from(new Set(migrationOps.map(op => op.fromAddr.toString()))) - const pnSessions = await parsedKeys.get('session', pnAddrStrings) + // Lock PN and LID per device so concurrent encrypt/decrypt can't + // race the read+rewrite below. + const mutexKeys = migrationOps.flatMap(op => [op.fromAddr.toString(), op.toAddr.toString()]) - // Prepare bulk session updates (PN → LID migration + deletion) - const sessionUpdates: { [key: string]: Uint8Array | null } = {} + return withSessionLocks(mutexKeys, () => + runIsolated( + async (): Promise<{ migrated: number; skipped: number; total: number }> => { + const totalOps = migrationOps.length + let migratedCount = 0 - for (const op of migrationOps) { - const pnAddrStr = op.fromAddr.toString() - const lidAddrStr = op.toAddr.toString() + // Bulk fetch PN sessions - already exist (verified during device discovery) + const pnAddrStrings = Array.from(new Set(migrationOps.map(op => op.fromAddr.toString()))) + const pnSessions = await parsedKeys.get('session', pnAddrStrings) - const pnSession = pnSessions[pnAddrStr] - if (pnSession) { - // Session exists (guaranteed from device discovery) - const fromSession = libsignal.SessionRecord.deserialize(pnSession) - if (fromSession.haveOpenSession()) { - // Queue for bulk update: copy to LID, delete from PN - sessionUpdates[lidAddrStr] = fromSession.serialize() - sessionUpdates[pnAddrStr] = null + // Prepare bulk session updates (PN → LID migration + deletion) + const sessionUpdates: { [key: string]: Uint8Array | null } = {} - migratedCount++ + for (const op of migrationOps) { + const pnAddrStr = op.fromAddr.toString() + const lidAddrStr = op.toAddr.toString() + + const pnSession = pnSessions[pnAddrStr] + if (pnSession) { + // Session exists (guaranteed from device discovery) + const fromSession = libsignal.SessionRecord.deserialize(pnSession) + if (fromSession.haveOpenSession()) { + // Queue for bulk update: copy to LID, delete from PN + sessionUpdates[lidAddrStr] = fromSession.serialize() + sessionUpdates[pnAddrStr] = null + + migratedCount++ + } } } - } - - // Single bulk session update for all migrations - if (Object.keys(sessionUpdates).length > 0) { - await parsedKeys.set({ session: sessionUpdates }) - logger.debug({ migratedSessions: migratedCount }, 'bulk session migration complete') - // Cache device-level migrations - for (const op of migrationOps) { - if (sessionUpdates[op.toAddr.toString()]) { - const deviceKey = `${op.pnUser}.${op.deviceId}` - migratedSessionCache.set(deviceKey, true) + // Single bulk session update for all migrations + if (Object.keys(sessionUpdates).length > 0) { + await parsedKeys.set({ session: sessionUpdates }) + logger.debug({ migratedSessions: migratedCount }, 'bulk session migration complete') + + // Cache device-level migrations + for (const op of migrationOps) { + if (sessionUpdates[op.toAddr.toString()]) { + const deviceKey = `${op.pnUser}.${op.deviceId}` + migratedSessionCache.set(deviceKey, true) + } } } - } - const skippedCount = totalOps - migratedCount - return { migrated: migratedCount, skipped: skippedCount, total: totalOps } - }, - `migrate-${deviceJids.length}-sessions-${jidDecode(toJid)?.user}` + const skippedCount = totalOps - migratedCount + return { migrated: migratedCount, skipped: skippedCount, total: totalOps } + }, + mutexKeys[0] ?? `migrate-${jidDecode(toJid)?.user}` + ) ) } } diff --git a/src/Socket/messages-recv.ts b/src/Socket/messages-recv.ts index 02d0a584091..ab6d94463bf 100644 --- a/src/Socket/messages-recv.ts +++ b/src/Socket/messages-recv.ts @@ -617,7 +617,6 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => { if (enableAutoSessionRecreation && messageRetryManager && retryCount > 1) { try { // Check if we have a session with this JID - const sessionId = signalRepository.jidToSignalProtocolAddress(fromJid) const hasSession = await signalRepository.validateSession(fromJid) const result = messageRetryManager.shouldRecreateSession(fromJid, hasSession.exists) shouldRecreateSession = result.recreate @@ -625,8 +624,7 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => { if (shouldRecreateSession) { logger.debug({ fromJid, retryCount, reason: recreateReason }, 'recreating session for retry') - // Delete existing session to force recreation - await authState.keys.set({ session: { [sessionId]: null } }) + await signalRepository.deleteSessionForJid(fromJid) forceIncludeKeys = true } } catch (error) { @@ -1383,7 +1381,7 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => { { participant, stored: info.registrationId, received: receivedRegId }, 'reg id mismatch on retry without bundle, deleting session' ) - await authState.keys.set({ session: { [sessionId]: null } }) + await signalRepository.deleteSessionForJid(participant) } } } @@ -1397,7 +1395,7 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => { } else if (retryCount > BASE_KEY_CHECK_RETRY) { if (messageRetryManager.hasSameBaseKey(sessionId, msgId, info.baseKey)) { logger.warn({ participant, retryCount }, 'base key collision on retry, forcing fresh session') - await authState.keys.set({ session: { [sessionId]: null } }) + await signalRepository.deleteSessionForJid(participant) } messageRetryManager.deleteBaseKey(sessionId, msgId) @@ -1417,7 +1415,7 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => { if (shouldRecreateSession) { logger.debug({ participant, retryCount, reason: recreateReason }, 'recreating session for outgoing retry') - await authState.keys.set({ session: { [sessionId]: null } }) + await signalRepository.deleteSessionForJid(participant) } } catch (error) { logger.warn({ error, participant }, 'failed to check session recreation for outgoing retry') diff --git a/src/Types/Auth.ts b/src/Types/Auth.ts index 40e34f0b1ef..565cdea7a92 100644 --- a/src/Types/Auth.ts +++ b/src/Types/Auth.ts @@ -98,6 +98,14 @@ export type SignalKeyStore = { export type SignalKeyStoreWithTransaction = SignalKeyStore & { isInTransaction: () => boolean transaction(exec: () => Promise, key: string): Promise + /** + * Like `transaction` but never reuses a surrounding ALS context and + * always commits before returning. Caller is responsible for any + * serialisation. Optional for backwards compatibility — external auth + * stores that implement this interface directly may omit it; callers + * (e.g. libsignal.ts) fall back to `transaction` in that case. + */ + isolatedTransaction?(exec: () => Promise): Promise } export type TransactionCapabilityOptions = { diff --git a/src/Types/Signal.ts b/src/Types/Signal.ts index 3c1e65a7d5b..3b52edd09cb 100644 --- a/src/Types/Signal.ts +++ b/src/Types/Signal.ts @@ -76,6 +76,9 @@ export type SignalRepository = { migrateSession(fromJid: string, toJid: string): Promise<{ migrated: number; skipped: number; total: number }> validateSession(jid: string): Promise<{ exists: boolean; reason?: string }> deleteSession(jids: string[]): Promise + /** Delete the session for `jid` under the per-session lock so it can't + * race with concurrent encrypt/decrypt on the same recipient. */ + deleteSessionForJid(jid: string): Promise } // Optimized repository with pre-loaded LID mapping store diff --git a/src/Utils/auth-utils.ts b/src/Utils/auth-utils.ts index 6b75bf3827c..d952809abd3 100644 --- a/src/Utils/auth-utils.ts +++ b/src/Utils/auth-utils.ts @@ -340,6 +340,31 @@ export const addTransactionCapability = ( } finally { releaseTxMutexRef(key) } + }, + + /** + * Like `transaction` but never reuses a surrounding ALS context and + * always commits before returning. Required when the work's writes + * must be visible to a sibling caller before the local scope ends + * (e.g. a per-recipient lock that needs the next acquirer to read + * fresh state from disk). Caller is responsible for serialisation. + */ + isolatedTransaction: async (exec: () => Promise): Promise => { + const ctx: TransactionContext = { + cache: {}, + mutations: {}, + dbQueries: 0 + } + + try { + const result = await txStorage.run(ctx, exec) + await commitWithRetry(ctx.mutations) + logger.trace({ dbQueries: ctx.dbQueries }, 'isolated transaction completed') + return result + } catch (error) { + logger.error({ err: error }, 'isolated transaction failed') + throw error + } } } } diff --git a/src/__tests__/Signal/lid-mapping.test.ts b/src/__tests__/Signal/lid-mapping.test.ts index 68e49e9d592..58b90762464 100644 --- a/src/__tests__/Signal/lid-mapping.test.ts +++ b/src/__tests__/Signal/lid-mapping.test.ts @@ -9,6 +9,9 @@ const mockKeys: jest.Mocked = { get: jest.fn() as any, set: jest.fn(), transaction: jest.fn(async (work: () => any) => await work()) as any, + isolatedTransaction: jest.fn>( + async (work: () => any) => await work() + ) as any, isInTransaction: jest.fn() } const logger = P({ level: 'silent' }) diff --git a/src/__tests__/Utils/tc-token.test.ts b/src/__tests__/Utils/tc-token.test.ts index a107c8f8e4c..7795f9617a9 100644 --- a/src/__tests__/Utils/tc-token.test.ts +++ b/src/__tests__/Utils/tc-token.test.ts @@ -30,6 +30,9 @@ const createMockKeys = (): jest.Mocked => ({ get: jest.fn() as any, set: jest.fn(), transaction: jest.fn(async (work: () => any) => await work()) as any, + isolatedTransaction: jest.fn>( + async (work: () => any) => await work() + ) as any, isInTransaction: jest.fn() })