Skip to content

Commit 35ce15d

Browse files
authored
fix: make handshake abortable (#442)
To allow doing things like having a single `AbortSignal` that can be used as a timeout for incoming connection establishment, allow passing it as an option to the `ConnectionEncrypter` `secureOutbound` and `secureInbound` methods. Previously we'd wrap the stream to be secured in an `AbortableSource`, however this has some [serious performance implications](ChainSafe/js-libp2p-gossipsub#361) and it's generally better to just use a signal to cancel an ongoing operation instead of racing every chunk that comes out of the source.
1 parent dadcd1c commit 35ce15d

File tree

3 files changed

+56
-21
lines changed

3 files changed

+56
-21
lines changed

src/noise.ts

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { unmarshalPrivateKey } from '@libp2p/crypto/keys'
2-
import { type MultiaddrConnection, type SecuredConnection, type PeerId, CodeError, type PrivateKey, serviceCapabilities, isPeerId } from '@libp2p/interface'
2+
import { type MultiaddrConnection, type SecuredConnection, type PeerId, CodeError, type PrivateKey, serviceCapabilities, isPeerId, type AbortOptions } from '@libp2p/interface'
33
import { peerIdFromKeys } from '@libp2p/peer-id'
44
import { decode } from 'it-length-prefixed'
55
import { lpStream, type LengthPrefixedStream } from 'it-length-prefixed-stream'
@@ -72,10 +72,10 @@ export class Noise implements INoiseConnection {
7272
* @param connection - streaming iterable duplex that will be encrypted
7373
* @param remotePeer - PeerId of the remote peer. Used to validate the integrity of the remote peer.
7474
*/
75-
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
75+
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, options?: { remotePeer?: PeerId, signal?: AbortSignal }): Promise<SecuredConnection<Stream, NoiseExtensions>>
7676
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (localPeer: PeerId, connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
7777
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (...args: any[]): Promise<SecuredConnection<Stream, NoiseExtensions>> {
78-
const { localPeer, connection, remotePeer } = this.parseArgs<Stream>(args)
78+
const { localPeer, connection, remotePeer, signal } = this.parseArgs<Stream>(args)
7979

8080
const wrappedConnection = lpStream(
8181
connection,
@@ -96,7 +96,9 @@ export class Noise implements INoiseConnection {
9696
const handshake = await this.performHandshakeInitiator(
9797
wrappedConnection,
9898
privateKey,
99-
remoteIdentityKey
99+
remoteIdentityKey, {
100+
signal
101+
}
100102
)
101103
const conn = await this.createSecureConnection(wrappedConnection, handshake)
102104

@@ -117,10 +119,10 @@ export class Noise implements INoiseConnection {
117119
* @param connection - streaming iterable duplex that will be encrypted.
118120
* @param remotePeer - optional PeerId of the initiating peer, if known. This may only exist during transport upgrades.
119121
*/
120-
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
122+
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, options?: { remotePeer?: PeerId, signal?: AbortSignal }): Promise<SecuredConnection<Stream, NoiseExtensions>>
121123
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (localPeer: PeerId, connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
122124
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (...args: any[]): Promise<SecuredConnection<Stream, NoiseExtensions>> {
123-
const { localPeer, connection, remotePeer } = this.parseArgs<Stream>(args)
125+
const { localPeer, connection, remotePeer, signal } = this.parseArgs<Stream>(args)
124126

125127
const wrappedConnection = lpStream(
126128
connection,
@@ -141,7 +143,9 @@ export class Noise implements INoiseConnection {
141143
const handshake = await this.performHandshakeResponder(
142144
wrappedConnection,
143145
privateKey,
144-
remoteIdentityKey
146+
remoteIdentityKey, {
147+
signal
148+
}
145149
)
146150
const conn = await this.createSecureConnection(wrappedConnection, handshake)
147151

@@ -162,7 +166,8 @@ export class Noise implements INoiseConnection {
162166
connection: LengthPrefixedStream,
163167
// TODO: pass private key in noise constructor via Components
164168
privateKey: PrivateKey,
165-
remoteIdentityKey?: Uint8Array | Uint8ArrayList
169+
remoteIdentityKey?: Uint8Array | Uint8ArrayList,
170+
options?: AbortOptions
166171
): Promise<HandshakeResult> {
167172
let result: HandshakeResult
168173
try {
@@ -175,7 +180,7 @@ export class Noise implements INoiseConnection {
175180
prologue: this.prologue,
176181
s: this.staticKey,
177182
extensions: this.extensions
178-
})
183+
}, options)
179184
this.metrics?.xxHandshakeSuccesses.increment()
180185
} catch (e: unknown) {
181186
this.metrics?.xxHandshakeErrors.increment()
@@ -192,7 +197,8 @@ export class Noise implements INoiseConnection {
192197
connection: LengthPrefixedStream,
193198
// TODO: pass private key in noise constructor via Components
194199
privateKey: PrivateKey,
195-
remoteIdentityKey?: Uint8Array | Uint8ArrayList
200+
remoteIdentityKey?: Uint8Array | Uint8ArrayList,
201+
options?: AbortOptions
196202
): Promise<HandshakeResult> {
197203
let result: HandshakeResult
198204
try {
@@ -205,7 +211,7 @@ export class Noise implements INoiseConnection {
205211
prologue: this.prologue,
206212
s: this.staticKey,
207213
extensions: this.extensions
208-
})
214+
}, options)
209215
this.metrics?.xxHandshakeSuccesses.increment()
210216
} catch (e: unknown) {
211217
this.metrics?.xxHandshakeErrors.increment()
@@ -241,7 +247,7 @@ export class Noise implements INoiseConnection {
241247
* TODO: remove this after `[email protected]` is released and only support the
242248
* newer style
243249
*/
244-
private parseArgs <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (args: any[]): { localPeer: PeerId, connection: Stream, remotePeer?: PeerId } {
250+
private parseArgs <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (args: any[]): { localPeer: PeerId, connection: Stream, remotePeer?: PeerId, signal?: AbortSignal } {
245251
// if the first argument is a peer id, we're using the [email protected] style
246252
if (isPeerId(args[0])) {
247253
return {
@@ -256,7 +262,8 @@ export class Noise implements INoiseConnection {
256262
return {
257263
localPeer: this.components.peerId,
258264
connection: args[0],
259-
remotePeer: args[1]
265+
remotePeer: args[1]?.remotePeer,
266+
signal: args[1]?.signal
260267
}
261268
}
262269
}

src/performHandshake.ts

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ import {
88
import { ZEROLEN, XXHandshakeState } from './protocol.js'
99
import { createHandshakePayload, decodeHandshakePayload } from './utils.js'
1010
import type { HandshakeResult, HandshakeParams } from './types.js'
11+
import type { AbortOptions } from '@libp2p/interface'
1112

12-
export async function performHandshakeInitiator (init: HandshakeParams): Promise<HandshakeResult> {
13+
export async function performHandshakeInitiator (init: HandshakeParams, options?: AbortOptions): Promise<HandshakeResult> {
1314
const { log, connection, crypto, privateKey, prologue, s, remoteIdentityKey, extensions } = init
1415

1516
const payload = await createHandshakePayload(privateKey, s.publicKey, extensions)
@@ -23,12 +24,12 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise
2324

2425
logLocalStaticKeys(xx.s, log)
2526
log.trace('Stage 0 - Initiator starting to send first message.')
26-
await connection.write(xx.writeMessageA(ZEROLEN))
27+
await connection.write(xx.writeMessageA(ZEROLEN), options)
2728
log.trace('Stage 0 - Initiator finished sending first message.')
2829
logLocalEphemeralKeys(xx.e, log)
2930

3031
log.trace('Stage 1 - Initiator waiting to receive first message from responder...')
31-
const plaintext = xx.readMessageB(await connection.read())
32+
const plaintext = xx.readMessageB(await connection.read(options))
3233
log.trace('Stage 1 - Initiator received the message.')
3334
logRemoteEphemeralKey(xx.re, log)
3435
logRemoteStaticKey(xx.rs, log)
@@ -38,7 +39,7 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise
3839
log.trace('All good with the signature!')
3940

4041
log.trace('Stage 2 - Initiator sending third handshake message.')
41-
await connection.write(xx.writeMessageC(payload))
42+
await connection.write(xx.writeMessageC(payload), options)
4243
log.trace('Stage 2 - Initiator sent message with signed payload.')
4344

4445
const [cs1, cs2] = xx.ss.split()
@@ -51,7 +52,7 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise
5152
}
5253
}
5354

54-
export async function performHandshakeResponder (init: HandshakeParams): Promise<HandshakeResult> {
55+
export async function performHandshakeResponder (init: HandshakeParams, options?: AbortOptions): Promise<HandshakeResult> {
5556
const { log, connection, crypto, privateKey, prologue, s, remoteIdentityKey, extensions } = init
5657

5758
const payload = await createHandshakePayload(privateKey, s.publicKey, extensions)
@@ -65,17 +66,17 @@ export async function performHandshakeResponder (init: HandshakeParams): Promise
6566

6667
logLocalStaticKeys(xx.s, log)
6768
log.trace('Stage 0 - Responder waiting to receive first message.')
68-
xx.readMessageA(await connection.read())
69+
xx.readMessageA(await connection.read(options))
6970
log.trace('Stage 0 - Responder received first message.')
7071
logRemoteEphemeralKey(xx.re, log)
7172

7273
log.trace('Stage 1 - Responder sending out first message with signed payload and static key.')
73-
await connection.write(xx.writeMessageB(payload))
74+
await connection.write(xx.writeMessageB(payload), options)
7475
log.trace('Stage 1 - Responder sent the second handshake message with signed payload.')
7576
logLocalEphemeralKeys(xx.e, log)
7677

7778
log.trace('Stage 2 - Responder waiting for third handshake message...')
78-
const plaintext = xx.readMessageC(await connection.read())
79+
const plaintext = xx.readMessageC(await connection.read(options))
7980
log.trace('Stage 2 - Responder received the message, finished handshake.')
8081
const receivedPayload = await decodeHandshakePayload(plaintext, xx.rs, remoteIdentityKey)
8182

test/noise.spec.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,4 +179,31 @@ describe('Noise', () => {
179179
assert(false, err.message)
180180
}
181181
})
182+
183+
it('should abort noise handshake', async () => {
184+
const abortController = new AbortController()
185+
abortController.abort()
186+
187+
const noiseInit = new Noise({
188+
peerId: localPeer,
189+
logger: defaultLogger()
190+
}, { staticNoiseKey: undefined, extensions: undefined })
191+
const noiseResp = new Noise({
192+
peerId: remotePeer,
193+
logger: defaultLogger()
194+
}, { staticNoiseKey: undefined, extensions: undefined })
195+
196+
const [inboundConnection, outboundConnection] = duplexPair<Uint8Array | Uint8ArrayList>()
197+
198+
await expect(Promise.all([
199+
noiseInit.secureOutbound(outboundConnection, {
200+
remotePeer,
201+
signal: abortController.signal
202+
}),
203+
noiseResp.secureInbound(inboundConnection, {
204+
remotePeer: localPeer
205+
})
206+
])).to.eventually.be.rejected
207+
.with.property('name', 'AbortError')
208+
})
182209
})

0 commit comments

Comments
 (0)