Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 179 additions & 118 deletions src/Signal/libsignal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { LIDMapping, SignalAuthState, SignalKeyStoreWithTransaction } from
import type { SignalRepositoryWithLIDStore } from '../Types/Signal'
import { generateSignalPubKey } from '../Utils'
import type { ILogger } from '../Utils/logger'
import { makeKeyedMutex } from '../Utils/make-mutex'
import {
isHostedLidUser,
isHostedPnUser,
Expand Down Expand Up @@ -57,11 +58,55 @@ export function makeLibSignalRepository(

const parsedKeys = auth.keys as SignalKeyStoreWithTransaction
const migratedSessionCache = new LRUCache<string, true>({
ttl: 3 * 24 * 60 * 60 * 1000, // 7 days
// Bounded to prevent unbounded growth for high-cardinality accounts.
max: 10_000,
ttl: 3 * 24 * 60 * 60 * 1000,
ttlAutopurge: true,
updateAgeOnGet: true
})

// `parsedKeys.transaction` skips its mutex when nested in another ALS
// transaction, so two outbound sends triggered from different recv
// handlers can hit the same Signal session in parallel. This mutex
// always acquires, keyed by the resolved (PN→LID) Signal address so
// PN- and LID-addressed calls converge.
const sessionMutex = makeKeyedMutex()

// Sorted to prevent deadlocks between callers asking for overlapping sets.
const withSessionLocks = async <T>(keys: string[], work: () => Promise<T>): Promise<T> => {
const sorted = Array.from(new Set(keys)).sort()
const acquire = async (idx: number): Promise<T> => {
if (idx >= sorted.length) return work()
return sessionMutex.mutex(sorted[idx]!, () => acquire(idx + 1))
}

return acquire(0)
}

// `signalStorage.{load,store}Session` resolve PN→LID internally; the cipher
// is constructed with this pre-resolved address so storage's `id` is already
// LID-form and its internal re-resolution becomes a no-op. Without this,
// the LID mapping could populate between mutex acquisition and the storage
// hook, locking one key and mutating another.
const resolveSignalAddress = async (jid: string): Promise<libsignal.ProtocolAddress> => {
if (isPnUser(jid) || isHostedPnUser(jid)) {
const lidForPN = await lidMapping.getLIDForPN(jid)
if (lidForPN) {
return jidToSignalProtocolAddress(lidForPN)
}
}

return jidToSignalProtocolAddress(jid)
}

const resolveSessionKey = async (jid: string): Promise<string> => (await resolveSignalAddress(jid)).toString()

// Fall back to `transaction` for external auth stores that don't implement
// `isolatedTransaction`. The fallback reintroduces the visibility race for
// those stores but preserves backwards compatibility.
const runIsolated = <T>(exec: () => Promise<T>, key: string): Promise<T> =>
parsedKeys.isolatedTransaction ? parsedKeys.isolatedTransaction(exec) : parsedKeys.transaction(exec, key)

const ensureSenderKeyAndCreateSkdm = async (group: string, meId: string) => {
const senderName = jidToSignalSenderKeyName(group, meId)
const senderNameStr = senderName.toString()
Expand Down Expand Up @@ -115,52 +160,48 @@ export function makeLibSignalRepository(
}, item.groupId)
},
async decryptMessage({ jid, type, ciphertext }) {
const addr = jidToSignalProtocolAddress(jid)
const addr = await resolveSignalAddress(jid)
const session = new libsignal.SessionCipher(storage, addr)

// Extract and save sender's identity key before decryption for identity change detection
if (type === 'pkmsg') {
const identityKey = extractIdentityFromPkmsg(ciphertext)
if (identityKey) {
const addrStr = addr.toString()
const identityChanged = await storage.saveIdentity(addrStr, identityKey)
if (identityChanged) {
logger.info({ jid, addr: addrStr }, 'identity key changed or new contact, session will be re-established')
const sessionKey = addr.toString()

// `saveIdentity` may delete the session when the identity key
// changes, so it must hold the lock alongside the decrypt.
return sessionMutex.mutex(sessionKey, () =>
runIsolated(async () => {
if (type === 'pkmsg') {
const identityKey = extractIdentityFromPkmsg(ciphertext)
if (identityKey) {
const identityChanged = await storage.saveIdentity(sessionKey, identityKey)
if (identityChanged) {
logger.info(
{ jid, addr: sessionKey },
'identity key changed or new contact, session will be re-established'
)
}
}
}
}
}

async function doDecrypt() {
let result: Buffer
switch (type) {
case 'pkmsg':
result = await session.decryptPreKeyWhisperMessage(ciphertext)
break
case 'msg':
result = await session.decryptWhisperMessage(ciphertext)
break
}

return result
}

// If it's not a sync message, we need to ensure atomicity
// For regular messages, we use a transaction to ensure atomicity
return parsedKeys.transaction(async () => {
return await doDecrypt()
}, jid)
switch (type) {
case 'pkmsg':
return await session.decryptPreKeyWhisperMessage(ciphertext)
case 'msg':
return await session.decryptWhisperMessage(ciphertext)
}
}, sessionKey)
)
},

async encryptMessage({ jid, data }) {
const addr = jidToSignalProtocolAddress(jid)
const addr = await resolveSignalAddress(jid)
const cipher = new libsignal.SessionCipher(storage, addr)

// Use transaction to ensure atomicity
return parsedKeys.transaction(async () => {
const { type: sigType, body } = await cipher.encrypt(data)
const type = sigType === 3 ? 'pkmsg' : 'msg'
return { type, ciphertext: Buffer.from(body, 'binary') }
}, jid)
const sessionKey = addr.toString()
return sessionMutex.mutex(sessionKey, () =>
runIsolated(async () => {
const { type: sigType, body } = await cipher.encrypt(data)
const type: 'pkmsg' | 'msg' = sigType === 3 ? 'pkmsg' : 'msg'
return { type, ciphertext: Buffer.from(body, 'binary') }
}, sessionKey)
)
},

async encryptGroupMessage({ group, meId, data }) {
Expand Down Expand Up @@ -205,12 +246,15 @@ export function makeLibSignalRepository(

async injectE2ESession({ jid, session }) {
logger.trace({ jid }, 'injecting E2EE session')
const cipher = new libsignal.SessionBuilder(storage, jidToSignalProtocolAddress(jid))
return parsedKeys.transaction(async () => {
// libsignal runtime accepts an absent prekey (initOutgoing checks `device.preKey && ...`)
// but the bundled .d.ts marks it required.
await cipher.initOutgoing(session as unknown as Parameters<typeof cipher.initOutgoing>[0])
}, jid)
const addr = await resolveSignalAddress(jid)
const cipher = new libsignal.SessionBuilder(storage, addr)
const sessionKey = addr.toString()
return sessionMutex.mutex(sessionKey, () =>
runIsolated(async () => {
// `.d.ts` marks prekey required but runtime allows it absent.
await cipher.initOutgoing(session as unknown as Parameters<typeof cipher.initOutgoing>[0])
}, sessionKey)
)
},
jidToSignalProtocolAddress(jid) {
return jidToSignalProtocolAddress(jid).toString()
Expand Down Expand Up @@ -238,20 +282,32 @@ export function makeLibSignalRepository(
}
},

async deleteSessionForJid(jid: string) {
const sessionKey = await resolveSessionKey(jid)
await sessionMutex.mutex(sessionKey, () =>
runIsolated(async () => {
await auth.keys.set({ session: { [sessionKey]: null } })
}, sessionKey)
)
},

async deleteSession(jids: string[]) {
if (!jids.length) return

// Convert JIDs to signal addresses and prepare for bulk deletion
const sessionUpdates: { [key: string]: null } = {}
jids.forEach(jid => {
const addr = jidToSignalProtocolAddress(jid)
sessionUpdates[addr.toString()] = null
})
const sessionKeys = await Promise.all(
jids.map(async jid => {
const sessionKey = await resolveSessionKey(jid)
sessionUpdates[sessionKey] = null
return sessionKey
})
)

// Single transaction for all deletions
return parsedKeys.transaction(async () => {
await auth.keys.set({ session: sessionUpdates })
}, `delete-${jids.length}-sessions`)
return withSessionLocks(sessionKeys, () =>
runIsolated(async () => {
await auth.keys.set({ session: sessionUpdates })
}, sessionKeys[0]!)
)
},

close() {
Expand Down Expand Up @@ -324,82 +380,87 @@ export function makeLibSignalRepository(
'bulk device migration complete - all user devices processed'
)

// Single transaction for all migrations
return parsedKeys.transaction(
async (): Promise<{ migrated: number; skipped: number; total: number }> => {
// Prepare migration operations with addressing metadata
type MigrationOp = {
fromJid: string
toJid: string
pnUser: string
lidUser: string
deviceId: number
fromAddr: libsignal.ProtocolAddress
toAddr: libsignal.ProtocolAddress
}
// Prepare migration operations with addressing metadata
type MigrationOp = {
fromJid: string
toJid: string
pnUser: string
lidUser: string
deviceId: number
fromAddr: libsignal.ProtocolAddress
toAddr: libsignal.ProtocolAddress
}

const migrationOps: MigrationOp[] = deviceJids.map(jid => {
const lidWithDevice = transferDevice(jid, toJid)
const fromDecoded = jidDecode(jid)!
const toDecoded = jidDecode(lidWithDevice)!

return {
fromJid: jid,
toJid: lidWithDevice,
pnUser: fromDecoded.user,
lidUser: toDecoded.user,
deviceId: fromDecoded.device || 0,
fromAddr: jidToSignalProtocolAddress(jid),
toAddr: jidToSignalProtocolAddress(lidWithDevice)
}
})
const migrationOps: MigrationOp[] = deviceJids.map(jid => {
const lidWithDevice = transferDevice(jid, toJid)
const fromDecoded = jidDecode(jid)!
const toDecoded = jidDecode(lidWithDevice)!

const totalOps = migrationOps.length
let migratedCount = 0
return {
fromJid: jid,
toJid: lidWithDevice,
pnUser: fromDecoded.user,
lidUser: toDecoded.user,
deviceId: fromDecoded.device || 0,
fromAddr: jidToSignalProtocolAddress(jid),
toAddr: jidToSignalProtocolAddress(lidWithDevice)
}
})

// Bulk fetch PN sessions - already exist (verified during device discovery)
const pnAddrStrings = Array.from(new Set(migrationOps.map(op => op.fromAddr.toString())))
const pnSessions = await parsedKeys.get('session', pnAddrStrings)
// Lock PN and LID per device so concurrent encrypt/decrypt can't
// race the read+rewrite below.
const mutexKeys = migrationOps.flatMap(op => [op.fromAddr.toString(), op.toAddr.toString()])

// Prepare bulk session updates (PN → LID migration + deletion)
const sessionUpdates: { [key: string]: Uint8Array | null } = {}
return withSessionLocks(mutexKeys, () =>
runIsolated(
async (): Promise<{ migrated: number; skipped: number; total: number }> => {
const totalOps = migrationOps.length
let migratedCount = 0

for (const op of migrationOps) {
const pnAddrStr = op.fromAddr.toString()
const lidAddrStr = op.toAddr.toString()
// Bulk fetch PN sessions - already exist (verified during device discovery)
const pnAddrStrings = Array.from(new Set(migrationOps.map(op => op.fromAddr.toString())))
const pnSessions = await parsedKeys.get('session', pnAddrStrings)

const pnSession = pnSessions[pnAddrStr]
if (pnSession) {
// Session exists (guaranteed from device discovery)
const fromSession = libsignal.SessionRecord.deserialize(pnSession)
if (fromSession.haveOpenSession()) {
// Queue for bulk update: copy to LID, delete from PN
sessionUpdates[lidAddrStr] = fromSession.serialize()
sessionUpdates[pnAddrStr] = null
// Prepare bulk session updates (PN → LID migration + deletion)
const sessionUpdates: { [key: string]: Uint8Array | null } = {}

migratedCount++
for (const op of migrationOps) {
const pnAddrStr = op.fromAddr.toString()
const lidAddrStr = op.toAddr.toString()

const pnSession = pnSessions[pnAddrStr]
if (pnSession) {
// Session exists (guaranteed from device discovery)
const fromSession = libsignal.SessionRecord.deserialize(pnSession)
if (fromSession.haveOpenSession()) {
// Queue for bulk update: copy to LID, delete from PN
sessionUpdates[lidAddrStr] = fromSession.serialize()
sessionUpdates[pnAddrStr] = null

migratedCount++
}
}
}
}

// Single bulk session update for all migrations
if (Object.keys(sessionUpdates).length > 0) {
await parsedKeys.set({ session: sessionUpdates })
logger.debug({ migratedSessions: migratedCount }, 'bulk session migration complete')

// Cache device-level migrations
for (const op of migrationOps) {
if (sessionUpdates[op.toAddr.toString()]) {
const deviceKey = `${op.pnUser}.${op.deviceId}`
migratedSessionCache.set(deviceKey, true)
// Single bulk session update for all migrations
if (Object.keys(sessionUpdates).length > 0) {
await parsedKeys.set({ session: sessionUpdates })
logger.debug({ migratedSessions: migratedCount }, 'bulk session migration complete')

// Cache device-level migrations
for (const op of migrationOps) {
if (sessionUpdates[op.toAddr.toString()]) {
const deviceKey = `${op.pnUser}.${op.deviceId}`
migratedSessionCache.set(deviceKey, true)
}
}
}
}

const skippedCount = totalOps - migratedCount
return { migrated: migratedCount, skipped: skippedCount, total: totalOps }
},
`migrate-${deviceJids.length}-sessions-${jidDecode(toJid)?.user}`
const skippedCount = totalOps - migratedCount
return { migrated: migratedCount, skipped: skippedCount, total: totalOps }
},
mutexKeys[0] ?? `migrate-${jidDecode(toJid)?.user}`
)
)
}
}
Expand Down
Loading
Loading